In [1]:
1+1

2

In [2]:
import os, re
import torch
import datetime
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import pandas as pd
from sklearn.model_selection import train_test_split

SEED = 42
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)


In [None]:
with open("../data/raw/AFDBv4_90.128-254.fasta", "r") as file:
    lines = file.readlines()
    sequences = [seq.strip() for seq in lines if not seq.startswith(">")]
    df = pd.DataFrame({'sequence':sequences})
train_data, test_data = train_test_split(df, test_size=0.5, shuffle = True, random_state=42)
test_data, val_data = train_test_split(test_data, test_size=0.5, shuffle = True, random_state=42)
train_data = pd.read_csv('data/train.txt', names=['sequence'])
test_data = pd.read_csv('data/test.txt', names=['sequence'])
val_data = pd.read_csv('data/val.txt', names=['sequence'])

In [3]:
with open('data/val.txt', 'r') as f:
    lines = f.readlines()
    sequences = [seq.strip() for seq in lines if not seq.startswith(">")]
    val_data_df = pd.DataFrame({'sequence':sequences})

In [4]:
def process_file(file_path):
    with open(file_path, 'r') as f:
        lines = f.readlines()

    output = []

    for line in lines:
        line = line.strip()
        n = len(line)

        idx1, idx2 = torch.randperm(n-2)[:2] + 1
        if idx1 > idx2:
            idx1, idx2 = idx2, idx1

        prefix, middle, suffix = line[:idx1], line[idx1:idx2], line[idx2:]
        
        p = rng.random()
        if p > 0.66: # PSM
            fim_sample = '@' + prefix + '$' + suffix + '#' + middle
        elif p > 0.33: # SPM
            fim_sample = '$' + suffix + '@' + prefix + '#' + middle
        else: # default
            fim_sample = prefix + middle + suffix

        output.append(fim_sample)

    return '.'.join(output)

train_data = process_file('data/train.txt')
val_data = process_file('data/val.txt')

In [5]:
vocab = sorted(list(set("".join(train_data))) + ['0']) #<PRE> = '@', <MID> = '#', <SUF> = '$', <EOS> = '.', <PAD> = '0'

stoi = {c:i for i,c in enumerate(vocab)}
itos = {i:c for i,c in enumerate(vocab)}

encode = lambda s: torch.LongTensor([stoi[c] for c in s])
decode = lambda l: "".join([itos[i] for i in l])
encode("@AC#TG$.")

tensor([ 4,  5,  6,  0, 21, 10,  1,  2])

In [6]:
def get_val_batch(device="cuda", mode='psm', indexes=None):
    data = val_data_df
    data = data.sequence.values
    
    ix = rng.integers(len(data), size=batch_size)
    xs1 = []
    ys1 = []
    xs2 = []
    ys2 = []
    xs3 = []
    ys3 = []

    mask = []
    for i in ix:
        # if mode == 'psm':
        document = data[i][:ctx_size]
        n = len(document)
        idx1, idx2 = torch.randperm(n-2)[:2] + 1
        if idx1 > idx2:
            idx1, idx2 = idx2, idx1

        prefix, middle, suffix = document[:idx1], document[idx1:idx2], document[idx2:]
        fim_sample = '@' + prefix + '$' + suffix + '#' + middle

        sample_x1 = encode(fim_sample[:ctx_size])
        sample_y1 = encode(fim_sample[1:ctx_size+1])

        # if mode == 'pms':
        # document = data[i][:ctx_size]
        # n = len(document)
        # idx1, idx2 = torch.randperm(n-2)[:2] + 1
        # prefix, middle, _ = document[:idx1], document[idx1:idx2], document[idx2:]
        fim_sample = prefix + middle

        sample_x2 = encode(fim_sample[:ctx_size])
        sample_y2 = encode(fim_sample[1:ctx_size+1])

        mask.append(torch.tensor(len(prefix)))

        # elif mode == 'default':
        sample_x3 = encode(data[i][:ctx_size])
        sample_y3 = encode(data[i][1:ctx_size+1])

        sample_x1 = F.pad(sample_x1, (0, max(0, ctx_size - len(sample_x1))), value=stoi['0'])
        sample_y1 = F.pad(sample_y1, (0, max(0, ctx_size - len(sample_y1))), value=stoi['0'])
        
        sample_x2 = F.pad(sample_x2, (0, max(0, ctx_size - len(sample_x2))), value=stoi['0'])
        sample_y2 = F.pad(sample_y2, (0, max(0, ctx_size - len(sample_y2))), value=stoi['0'])
        
        sample_x3 = F.pad(sample_x3, (0, max(0, ctx_size - len(sample_x3))), value=stoi['0'])
        sample_y3 = F.pad(sample_y3, (0, max(0, ctx_size - len(sample_y3))), value=stoi['0'])


        xs1.append(sample_x1)
        ys1.append(sample_y1)
        
        xs2.append(sample_x2)
        ys2.append(sample_y2)

        xs3.append(sample_x3)
        ys3.append(sample_y3)

    x1 = torch.stack(xs1).to(device)
    y1 = torch.stack(ys1).to(device)

    x2 = torch.stack(xs2).to(device)
    y2 = torch.stack(ys2).to(device)

    x3 = torch.stack(xs3).to(device)
    y3 = torch.stack(ys3).to(device)

    mask = torch.stack(mask).to(device)
    return x1, x2, x3, y1, y2, y3, mask


In [7]:
batch_size = 512
ctx_size = 512

def get_batch(split, device="cuda"):
    data = train_data if split == 'train' else val_data
    ix = rng.integers(len(data) - ctx_size, size=batch_size)
    x = torch.stack([encode(data[i:i+ctx_size]) for i in ix])
    y = torch.stack([encode(data[i+1:i+ctx_size+1]) for i in ix])
    return x.to(device), y.to(device)


In [8]:
xb, yb = get_batch('train', 'cpu')
print('inputs:', xb.shape)
# print(xb)
print('targets:', yb.shape)
# print(yb)

# for t in range(ctx_size):
#     context = xb[0,:t+1]
#     target = yb[0,t]
#     print(f"when input is {decode(context.numpy())} the target: {itos[int(target.numpy())]}")


inputs: torch.Size([512, 512])
targets: torch.Size([512, 512])


In [9]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# xb, yb, _ = get_val_batch(device=device, mode='psm')
# print('inputs:', xb.shape)
# # print(xb)
# print('targets:', yb.shape)
# for t in range(ctx_size):
#     context = xb[1,:t+1]
#     target = yb[1,t]
#     print(f"when input is {decode(context.cpu().numpy())} the target: {itos[int(target.cpu().numpy())]}")


In [10]:
class Head(nn.Module):
    def __init__(self, embed_size, head_embed_size, dropout=0):
        super().__init__()
        self.key = nn.Linear(embed_size, head_embed_size, bias=False)
        self.query = nn.Linear(embed_size, head_embed_size, bias=False)
        self.value = nn.Linear(embed_size, head_embed_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("tril", torch.tril(torch.ones((ctx_size, ctx_size))))
    
    def forward(self, inputs):
        _,T,_ = inputs.shape # batch, ctx, vocab_size
        k = self.key(inputs) # batch, ctx, head_size
        q = self.query(inputs) # batch, ctx, head_size
        mask = q @ k.transpose(-2, -1) # batch, ctx, ctx
        mask = mask.masked_fill(self.tril[:T,:T] == 0, float("-inf"))
        mask = F.softmax(mask, dim=-1)
        mask = self.dropout(mask)
        v = self.value(inputs) # batch, ctx, head_size
        out = mask @ v # batch, ctx, head_size
        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, embed_size, dropout=0):
        super().__init__()
        head_embed_size = embed_size // num_heads
        self.heads = nn.ModuleList([Head(embed_size, head_embed_size, dropout) for _ in range(num_heads)])
        self.ff = nn.Linear(head_embed_size*num_heads, embed_size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, inputs):
        # batch, ctx, embed_size
        inputs = torch.cat([head(inputs) for head in self.heads], dim=-1) # batch, ctx, num_head*head_size --> batch, ctx, embed_size
        inputs = self.ff(inputs) # batch, ctx, embed_size
        inputs = self.dropout(inputs)
        return inputs
    


class Block(nn.Module):
    def __init__(self, num_heads, embed_size, dropout=0):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_size)
        self.mha = MultiHeadAttention(num_heads, embed_size, dropout)
        self.ln2 = nn.LayerNorm(embed_size)
        self.ff = nn.Sequential(
            nn.Linear(embed_size, embed_size),
            nn.ReLU(),
            nn.Linear(embed_size, embed_size),
            nn.Dropout(dropout)
        )
    
    def forward(self, inputs):
        inputs = self.mha(self.ln1(inputs)) + inputs # batch, ctx, embed_size
        inputs = self.ff(self.ln2(inputs)) + inputs  # batch, ctx, embed_size
        return inputs


In [11]:
len(vocab)

25

In [12]:
embed_size = 512
num_heads = 3
num_layers = 3
dropout = 0.0
lr = 3e-5
steps = 2001
evel_interv = 250
generate_interv = 1000
device = 'cuda' if torch.cuda.is_available() else 'cpu'
VOCAB_SIZE = len(vocab)


class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, embed_size)
        self.blocks = nn.Sequential(*[Block(num_heads, embed_size) for _ in range(num_layers)])
        self.ln = nn.LayerNorm(embed_size)
        self.ff = nn.Linear(embed_size, VOCAB_SIZE)
    
    def forward(self, inputs, targets=None):
        # batch, ctx, vocab_size
        logits = self.embedding(inputs) # batch, ctx, embed_size
        logits = self.blocks(logits) # batch, ctx, embed_size
        logits = self.ln(logits) # batch, ctx, embed_size
        logits = self.ff(logits) # batch, ctx, vocab_size

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            lv = logits.view((B*T, C))
            tv = targets.view((B*T,))
            loss = F.cross_entropy(lv, tv)
        
        return logits, loss

    def generate(self, tokens, n=256):
        for _ in range(n):
            logits, loss = self(tokens[:,-ctx_size:])
            logits = logits[:,-1,:]
            logits = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(logits, num_samples=1)
            if decode(next_token[0].cpu().numpy()) == '.':
                break
            tokens = torch.cat((tokens, next_token), dim=1)
        return tokens

    
    def calculate_loss(self, batch_tokens, targets, mode='psm', indexes=None):
        if mode == 'psm':
            mask = torch.zeros_like(batch_tokens).to(device)

            hash_index = torch.tensor(stoi['#']).to(device)
            indices = (batch_tokens == hash_index).nonzero()[:,1] + 1
            for b, idx in enumerate(indices):
                mask[b, idx:] = 1
                
            pad_index = torch.tensor(stoi['0']).to(device)
            mask = (mask & (batch_tokens != pad_index)).to(torch.long)
            
            logits, _ = self(batch_tokens)
            logits = logits[:,-batch_tokens.shape[1] + 1:,:]
            loss = F.cross_entropy(logits.permute(0, 2, 1), targets[:,1:], reduction='none')
            loss = (loss * mask[:,1:]).sum() / mask.sum()

        elif mode == 'pms':
            mask = torch.zeros_like(batch_tokens).to(device)

            for b, idx in enumerate(indexes):
                mask[b, idx:] = 1
            
            pad_index = torch.tensor(stoi['0']).to(device)
            mask = (mask & (batch_tokens != pad_index)).to(torch.long)


            logits, _ = self(batch_tokens)
            logits = logits[:, -batch_tokens.shape[1] + 1:, :]
            loss = F.cross_entropy(logits.permute(0, 2, 1), targets[:, 1:], reduction='none')

            loss = (loss * mask[:, 1:]).sum() / mask.sum()
            
        elif mode == 'default':
            pad_index = torch.tensor(stoi['0']).to(device)
            mask = (batch_tokens != pad_index).to(torch.long)
            logits, _ = self(batch_tokens)
            logits = logits[:,-batch_tokens.shape[1] + 1:,:]
            loss = F.cross_entropy(logits.permute(0, 2, 1), targets[:,1:], reduction='none')
            loss = (loss * mask[:,1:]).sum() / mask.sum()

        return loss


In [13]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [14]:
model = LanguageModel()
model.to(device)
device

'cuda'

In [15]:
model = LanguageModel()
model.to(device)
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')
xb, yb = get_batch("train", device)
logits, loss = model(xb, yb)
print("loss:", loss.item())

print(decode(model.generate(torch.ones((1,1), dtype=torch.int64).to(device), 64)[0].cpu().numpy()))


print(f"batch_size={batch_size} ctx_size={ctx_size} embed_size={embed_size} num_heads={num_heads} num_layers={num_layers} dropout={dropout}")

optim = torch.optim.AdamW(model.parameters(), lr=3e-5)

model.train()

for step in range(steps):
    xb, yb = get_batch('train', device)

    if step % 100 == 0:
        model.eval()
        with torch.no_grad():
            splits = {}
            for split in ['train', 'val']:
                losses = torch.zeros(1)
                for k in range(1):
                    X, y = get_batch(split, device)
                    logits, loss = model(X, y)
                    losses[k] = loss.item()
                splits[split] = losses.mean()
                        
            x1, x2, x3, y1, y2, y3, ixs = get_val_batch(device=device)
            loss_1 = model.calculate_loss(x3, y3, mode='default')
            
            loss_2 = model.calculate_loss(x2, y2, mode='pms', indexes=ixs)
            
            loss_3 = model.calculate_loss(x1, y1, mode='default')

            loss_4 = model.calculate_loss(x1, y1, mode='psm')

            print(f"Step {step:4}: train loss {splits['train']:.5f} | val loss {splits['val']:.5f}"
             f" | AR loss {loss_1} | AR middle loss {loss_2} | FIM loss {loss_3} | FIM middle loss {loss_4}")

        model.train()

    _, loss = model(xb, yb)
    optim.zero_grad()
    loss.backward()
    optim.step()

print()
print(loss.item())


4.743705 M parameters
loss: 3.355654239654541
$H#FP$GAHWVDQCNMAMNID
batch_size=512 ctx_size=512 embed_size=512 num_heads=3 num_layers=3 dropout=0.0
Step    0: train loss 3.35601 | val loss 3.35695 | AR loss 3.342895030975342 | AR middle loss 3.3614940643310547 | FIM loss 3.340390682220459 | FIM middle loss 3.3592135906219482
Step  100: train loss 2.90756 | val loss 2.90893 | AR loss 2.899994134902954 | AR middle loss 2.9614362716674805 | FIM loss 2.9317800998687744 | FIM middle loss 2.960427761077881
Step  200: train loss 2.90424 | val loss 2.90392 | AR loss 2.894641876220703 | AR middle loss 2.9637420177459717 | FIM loss 2.9269349575042725 | FIM middle loss 2.9620444774627686
Step  300: train loss 2.90444 | val loss 2.90104 | AR loss 2.8839876651763916 | AR middle loss 2.9621920585632324 | FIM loss 2.916936159133911 | FIM middle loss 2.961125135421753
Step  400: train loss 2.90234 | val loss 2.90045 | AR loss 2.8943653106689453 | AR middle loss 2.9754397869110107 | FIM loss 2.92659926

In [16]:
path = "./gpt_5M.pth"
torch.save(model, path)


In [4]:
!python3 train_ddp.py

cuda:0
Step    0: loss 3.41977
Step  100: loss 2.90874
Step  200: loss 2.90608
Step  300: loss 2.90013
Step  400: loss 2.90002
Step  500: loss 2.89754
Step  600: loss 2.90122
Step  700: loss 2.89556
Step  800: loss 2.89644
Step  900: loss 2.90114
Step 1000: loss 2.89701
Step 1100: loss 2.89349
Step 1200: loss 2.89301
Step 1300: loss 2.89157
Step 1400: loss 2.89195
Step 1500: loss 2.89415
Step 1600: loss 2.89119
Step 1700: loss 2.88981
Step 1800: loss 2.88975
Step 1900: loss 2.89163
Step 2000: loss 2.89196

Final loss: 2.891958713531494
