In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
from tokenizer import MostFrequentWordsTokenizer, BytePairEncodingTokenizer
from model import FeedForward, MultiHeadAttention
from trainer import TrainerGPT

In [2]:
VOCAB_SIZE = 2000
D_MODEL = 384
N_HEADS = 6
N_LAYERS = 2
DROPOUT = 0.2
MAX_ITERS = 2000
EVAL_INTERVAL = 20
BATCH_SIZE = 128
BLOCK_SIZE = 256
LEARNING_RATE = 3e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
df = pd.read_csv('data/anamnese.csv')
text_corpus = df['transcription'].dropna().astype(str).tolist()
text = " ".join(text_corpus)
df['medical_specialty'].value_counts()


medical_specialty
 Surgery                          1103
 Consult - History and Phy.        516
 Cardiovascular / Pulmonary        372
 Orthopedic                        355
 Radiology                         273
 General Medicine                  259
 Gastroenterology                  230
 Neurology                         223
 SOAP / Chart / Progress Notes     166
 Obstetrics / Gynecology           160
 Urology                           158
 Discharge Summary                 108
 ENT - Otolaryngology               98
 Neurosurgery                       94
 Hematology - Oncology              90
 Ophthalmology                      83
 Nephrology                         81
 Emergency Room Reports             75
 Pediatrics - Neonatal              70
 Pain Management                    62
 Psychiatry / Psychology            53
 Office Notes                       51
 Podiatry                           47
 Dermatology                        29
 Cosmetic / Plastic Surgery         27
 Dentis

In [4]:
tokenizer = MostFrequentWordsTokenizer(vocab_size=VOCAB_SIZE)
tokenizer.build_vocab(text_corpus)
print(tokenizer.tokenize(text_corpus[0]))
print(text_corpus[0])

['subjective', 'this', '23yearold', 'white', 'female', 'presents', 'with', 'complaint', 'of', 'allergies', 'she', 'used', 'to', 'have', 'allergies', 'when', 'she', 'lived', 'in', 'seattle', 'but', 'she', 'thinks', 'they', 'are', 'worse', 'here', 'in', 'the', 'past', 'she', 'has', 'tried', 'claritin', 'and', 'zyrtec', 'both', 'worked', 'for', 'short', 'time', 'but', 'then', 'seemed', 'to', 'lose', 'effectiveness', 'she', 'has', 'used', 'allegra', 'also', 'she', 'used', 'that', 'last', 'summer', 'and', 'she', 'began', 'using', 'it', 'again', 'two', 'weeks', 'ago', 'it', 'does', 'not', 'appear', 'to', 'be', 'working', 'very', 'well', 'she', 'has', 'used', 'overthecounter', 'sprays', 'but', 'no', 'prescription', 'nasal', 'sprays', 'she', 'does', 'have', 'asthma', 'but', 'doest', 'not', 'require', 'daily', 'medication', 'for', 'this', 'and', 'does', 'not', 'think', 'it', 'is', 'flaring', 'upmedications', 'her', 'only', 'medication', 'currently', 'is', 'ortho', 'tricyclen', 'and', 'the', 'al

In [5]:
tokenizer = BytePairEncodingTokenizer(vocab_size=VOCAB_SIZE)
tokenizer.build_vocab(text_corpus)
print(tokenizer.tokenize(text_corpus[0]))
print(text_corpus[0])




['SU', 'B', 'J', 'EC', 'TI', 'VE', ':,', 'This', '2', '3', '-', 'year', '-', 'old', 'white', 'female', 'present', 's', 'with', 'complain', 't', 'of', 'allerg', 'ies', '.', 'She', 'used', 'to', 'have', 'allerg', 'ies', 'when', 'she', 'li', 'ved', 'in', 'Se', 'att', 'le', 'but', 'she', 'th', 'in', 'ks', 'they', 'are', 'wor', 'se', 'here', '.', 'In', 'the', 'past', ',', 'she', 'has', 'tr', 'ied', 'C', 'lar', 'it', 'in', ',', 'and', 'Z', 'y', 'r', 'te', 'c', '.', 'Bo', 'th', 'work', 'ed', 'for', 'short', 'time', 'but', 'then', 'se', 'em', 'ed', 'to', 'lo', 'se', 'ef', 'fect', 'iv', 'en', 'ess', '.', 'She', 'has', 'used', 'Al', 'le', 'gr', 'a', 'also', '.', 'She', 'used', 'that', 'last', 'su', 'mm', 'er', 'and', 'she', 'be', 'gan', 'using', 'it', 'again', 'two', 'weeks', 'ago', '.', 'It', 'does', 'not', 'appear', 'to', 'be', 'work', 'ing', 'very', 'well', '.', 'She', 'has', 'used', 'over', '-', 'the', '-', 'coun', 'ter', 'sp', 'ray', 's', 'but', 'no', 'pres', 'cr', 'ip', 'tion', 'nasal',

In [6]:
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
  data = train_data if split == 'train' else val_data 
  ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
  x = torch.stack([data[i:i+BLOCK_SIZE] for i in ix])
  y = torch.stack([data[i+1:i+BLOCK_SIZE+1] for i in ix])
  return x.to(DEVICE), y.to(DEVICE)

In [7]:
data = torch.tensor(tokenizer.encode(text), dtype=torch.long)

def get_dataset(data, train_size=0.9, split='train'):
  n = int(train_size * len(data))
  train_data, val_data = data[:n], data[n:]
  data = train_data if split == 'train' else val_data

  def get_batch():
    ix = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data[i:i+BLOCK_SIZE] for i in ix])
    y = torch.stack([data[i+1:i+BLOCK_SIZE+1] for i in ix])
    return x.to(DEVICE), y.to(DEVICE)

  return get_batch

train_data = get_dataset(data, split='train')
val_data = get_dataset(data, split='val')

x, y = train_data()
x.shape, y.shape

tokenizer.decode(x[0].tolist()), tokenizer.decode(y[0].tolist())

RuntimeError: random_ expects 'from' to be less than 'to', but got from=0 >= to=-256

In [None]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout, ff_dim):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.attention_norm = nn.LayerNorm(d_model)
        self.feed_forward = FeedForward(d_model, ff_dim, dropout)
        self.feed_forward_norm = nn.LayerNorm(d_model)
        self.residual_dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # --- Self-attention sublayer ---
        attention_output = self.self_attention(x, mask)
        x = self.attention_norm(x + self.residual_dropout(attention_output))

        # --- Feed-forward sublayer ---
        feed_forward_output = self.feed_forward(x)
        x = self.feed_forward_norm(x + self.residual_dropout(feed_forward_output))
        return x

In [None]:
class MiniGPT(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model,
        n_heads,
        num_layers,
        block_size,
        dropout=0.1,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size

        # --- Embedding tables ---
        self.output_embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = nn.Embedding(block_size, d_model)

        # --- Decoder stack ---
        self.layers = nn.ModuleList(
            [
                TransformerDecoderBlock(
                    d_model=d_model,
                    n_heads=n_heads,
                    dropout=dropout,
                    ff_dim=d_model * 4,
                )
                for _ in range(num_layers)
            ]
        )

        self.final_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)

        causal_mask = torch.tril(torch.ones(block_size, block_size)).unsqueeze(0)
        self.register_buffer("causal_mask", causal_mask)

    def forward(self, token_ids, targets=None):
        batch_size, sequence_length = token_ids.shape

        token_vectors = self.output_embedding(token_ids)
        position_indices = torch.arange(sequence_length, device=token_ids.device)
        position_vectors = self.position_encoding(position_indices)

        x = token_vectors + position_vectors

        causal_mask = self.causal_mask[:, :sequence_length, :sequence_length]
        for layer in self.layers:
            x = layer(x, causal_mask)

        # --- Project back to token logits ---
        x = self.final_norm(x)
        return self.output_projection(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits = self(idx_cond)
            logits_last_step = logits[:, -1, :]
            probs = F.softmax(logits_last_step, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [None]:
model = MiniGPT(VOCAB_SIZE, D_MODEL, N_HEADS, N_LAYERS, BLOCK_SIZE).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
def get_loss(logits, targets):
    return nn.functional.cross_entropy(logits.reshape(-1, VOCAB_SIZE), targets.reshape(-1))

@torch.no_grad()
def get_validation_loss():
    model.eval()
    losses = torch.zeros(EVAL_INTERVAL)
    for k in range(EVAL_INTERVAL):
        X, Y = get_batch('val')
        losses[k] = get_loss(X, Y).item()
    loss = losses.mean()
    model.train()
    return loss

In [None]:
trainer = TrainerGPT()

trainer.fit(model, data, max_iters=MAX_ITERS,
    eval_interval=EVAL_INTERVAL,
    learning_rate=LEARNING_RATE,
    tokenizer=tokenizer,
    block_size=BLOCK_SIZE,
    batch_size=BATCH_SIZE,
    device=DEVICE,
    eval_prompt="The patient has a history of ")


Starting MiniGPT training...
Prompt: The patient has a history of 

--- Generated Text ---
The patient has a history of the medial inter es les ation io De were con bral as PERFORMED; hand ital ro, fever oma foot see ac and art
step 20: train loss 6.3953, val loss 5.9658, val perplexity 389.8454
Prompt: The patient has a history of 

--- Generated Text ---
The patient has a history of appear and second ro it ts, plan was discharged as ment and sub in x of the long free came in the se of
step 40: train loss 5.4721, val loss 5.1372, val perplexity 170.2419
Prompt: The patient has a history of 

--- Generated Text ---
The patient has a history of En any evidence / 60 W headache to the patient is normal lim its findings. The g ain his vis it ory as ound
step 60: train loss 4.7576, val loss 4.6132, val perplexity 100.8030
Prompt: The patient has a history of 

--- Generated Text ---
The patient has a history of cr hel m ye lo o em p ic ul a ur ine man ent and fat ly increased 7, lymph sutu

KeyboardInterrupt: 