<a href="https://colab.research.google.com/github/smBello-tse/TCN/blob/main/Vanilla_TCN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
from math import floor
from tqdm import tqdm

In [2]:
with open("shakespeare.txt", "r") as f:
    text = f.read()
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [3]:
vocab = list(set(text))
vocab.sort()
char2id = {vocab[i]: i for i in range(len(vocab))}
id2char = {i: vocab[i] for i in range(len(vocab))}

In [4]:
def encode(text):
    return [char2id[ch] for ch in text]

def decode(ids):
    if type(ids) == int: ids = [ids]
    return "".join([id2char[id] for id in ids])

print(decode(encode("Richard")))

Richard


In [5]:
class TextDataset(torch.utils.data.Dataset):
    def __init__(self, text, context_size):
        super(TextDataset, self).__init__()
        self.text = text
        self.context_size = context_size

    def __len__(self):
        if len(self.text) % self.context_size == 0: return floor(len(self.text) / self.context_size) * self.context_size
        else: return floor(len(self.text) / self.context_size) * self.context_size + len(self.text) % self.context_size - 1

    def __getitem__(self, idx):
        offset = idx % self.context_size
        start = idx - offset
        window = encode(self.text[start:start + offset + 2])
        while len(window) < self.context_size + 1:
            window = encode(".") + window
        return torch.tensor(window[:-1], dtype=torch.long), torch.tensor(window[-1], dtype=torch.long)

In [6]:
dts = TextDataset("Bobby is not here today", 11)
print(len(dts))
for i in range(len(dts)):
    x, y = dts[i]
    print(decode(x.tolist()), "--->", decode(y.tolist()))

22
..........B ---> o
.........Bo ---> b
........Bob ---> b
.......Bobb ---> y
......Bobby --->  
.....Bobby  ---> i
....Bobby i ---> s
...Bobby is --->  
..Bobby is  ---> n
.Bobby is n ---> o
Bobby is no ---> t
..........t --->  
.........t  ---> h
........t h ---> e
.......t he ---> r
......t her ---> e
.....t here --->  
....t here  ---> t
...t here t ---> o
..t here to ---> d
.t here tod ---> a
t here toda ---> y


In [55]:
context_size = 32 # Increased from 8
vocab_size = len(vocab)
train_dataset = TextDataset(text, context_size)
x, y = train_dataset[6]
print(decode(x.tolist()))
batch_size = 512
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)

.........................First C


In [46]:
class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(CausalConv1d, self).__init__()
        self.x_causal_pad = kernel_size - 1
        self.causalConv1d = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)

    def forward(self, x):
        x_pad = nn.functional.pad(x, (self.x_causal_pad, 0), "constant", 0.)
        return self.causalConv1d(x_pad)


x = torch.tensor([[[3., 2., 4., -1., 6.]]])
print(x.shape)
conv = CausalConv1d(1, 1, 3)
print(conv.causalConv1d.weight)
print(conv(x))
print(x[0, 0, 0] * conv.causalConv1d.weight[0, 0, -1])
print(conv.causalConv1d.bias)

torch.Size([1, 1, 5])
Parameter containing:
tensor([[[ 0.4616,  0.2003, -0.0968]]], requires_grad=True)
tensor([[[-0.1900,  0.5078,  1.4987,  1.9216,  1.1658]]],
       grad_fn=<ConvolutionBackward0>)
tensor(-0.2903, grad_fn=<MulBackward0>)
Parameter containing:
tensor([0.1003], requires_grad=True)


In [47]:
class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([nn.Sequential(CausalConv1d(cin, cout, k),
                                                   nn.MaxPool1d(kernel_size=2),
                                                   nn.ReLU())
                                     for cin, cout, k in zip(in_channels, out_channels, kernel_sizes)])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [48]:
class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_sizes):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                                                   CausalConv1d(cin, cout, k),
                                                   nn.ReLU())
                                     for cin, cout, k in zip(in_channels, out_channels, kernel_sizes)])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [62]:
class TCN(nn.Module):
    def __init__(self, vocab_size, d_emb, init_len, final_len, in_channels, out_channels, kernel_sizes, fc_dims=None):
        super(TCN, self).__init__()
        self.init_len = init_len
        self.embedding = nn.Embedding(vocab_size, d_emb)
        self.encoder = Encoder(in_channels, out_channels, kernel_sizes)
        #self.decoder = Decoder(out_channels, in_channels, reversed(kernel_sizes))
        self.fcs = None if fc_dims is None else nn.ModuleList([nn.Sequential(nn.Linear(in_features=fc_dim_in, out_features=fc_dim_out), nn.LeakyReLU()) for fc_dim_in, fc_dim_out in zip(fc_dims[:-1], fc_dims[1:])])
        self.head = nn.Linear(in_features=final_len, out_features=vocab_size) if fc_dims is None else nn.Linear(in_features=fc_dims[-1], out_features=vocab_size)

    def forward(self, x):
        N, _ = x.shape
        x = self.embedding(x)
        x = self.encoder(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1).reshape(N, -1)
        if self.fcs is None: return self.head(x)
        else:
          for fc in self.fcs:
            x = fc(x)
          return self.head(x)

    def sampling(self, start_chs, seq_len=100, device="cpu"):
        # Encode starting character
        seq = encode(start_chs)
        if len(seq) < self.init_len:
            nb_start_dots = 0
            while len(seq) < self.init_len:
                seq = encode(".") + seq
                nb_start_dots += 1
        start_idx = 0
        for t in range(seq_len):
            x = torch.tensor(seq[start_idx:], device=device, dtype=torch.long).unsqueeze(0)
            #print(x)

            # Forward step
            logits = self.forward(x)
            #print(logits.shape)
            probs = nn.Softmax(dim=-1)(logits).squeeze(0)

            # Sample next token
            next_ch = torch.multinomial(probs, num_samples=1).item()
            seq.append(next_ch)
            start_idx += 1

        return decode(seq[nb_start_dots:])


In [63]:
d_emb = 32 # Increased from 64
nb_conv_layers = 4
in_channels = [32, 64, 128, 256]
out_channels = [64, 128, 256, 512]
kernel_sizes = [5, 5, 3, 3]
final_len = context_size
for _ in range(len(out_channels)):
    final_len = final_len // 2
final_len *= out_channels[-1]
fc_dims = [final_len, 512, 256, 128]
model = TCN(vocab_size, d_emb, context_size, final_len, in_channels, out_channels, kernel_sizes, fc_dims)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
#print("Model created")

TCN(
  (embedding): Embedding(65, 32)
  (encoder): Encoder(
    (layers): ModuleList(
      (0): Sequential(
        (0): CausalConv1d(
          (causalConv1d): Conv1d(32, 64, kernel_size=(5,), stride=(1,))
        )
        (1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (2): ReLU()
      )
      (1): Sequential(
        (0): CausalConv1d(
          (causalConv1d): Conv1d(64, 128, kernel_size=(5,), stride=(1,))
        )
        (1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (2): ReLU()
      )
      (2): Sequential(
        (0): CausalConv1d(
          (causalConv1d): Conv1d(128, 256, kernel_size=(3,), stride=(1,))
        )
        (1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (2): ReLU()
      )
      (3): Sequential(
        (0): CausalConv1d(
          (causalConv1d): Conv1d(256, 512, kernel_size=(3,), stride=(1,))
        )
        (1): MaxPool1d(kernel_siz

In [64]:
print(model.sampling("Richard", 10, device))

RichardRGVGFdfQqt


In [65]:
lr = 0.001
optimizer = optim.AdamW(params=model.parameters(), lr=lr, weight_decay=0.01)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99)

In [66]:
p = 0
min_loss = 10000
epochs = 50  # Increased from 20 to 50
sampling_freq = epochs // 10
loss_fn = nn.CrossEntropyLoss()
for epoch in range(epochs):
    losses = []
    for idx, (x, y) in enumerate(tqdm(train_loader)):
        model.train()
        x = x.to(device)
        y = y.to(device)
        y_logits = model(x)
        optimizer.zero_grad()
        #print(y.shape)
        loss = loss_fn(y_logits.view(-1, vocab_size), y.view(-1))
        loss.backward()
        losses.append(loss.item())
        if loss < min_loss:
            min_loss = loss
            best_model = model
            best_epoch = epoch
    print(f"Epoch {epoch}, loss= {np.mean(np.array(losses))}")

    if epoch % sampling_freq == 0:
        model.eval()
        print(f"Epoch {epoch}")
        print("Sampling...")
        print(model.sampling("A", 100, device))
    optimizer.step()
    scheduler.step()

100%|██████████| 2179/2179 [00:47<00:00, 45.49it/s]


Epoch 0, loss= 4.187439672987414
Epoch 0
Sampling...
AdD-Pw-gcNemSlPDVQZ&QH3bXyBmRSPKD.C&Nblaqn$nsdVf:mtLiljVCkhWtfMxQBXmSej:YRq&WHEl;&FgxSPsADW.N-zrLcWkf


100%|██████████| 2179/2179 [00:46<00:00, 47.03it/s]


Epoch 1, loss= 4.1744874242578645


100%|██████████| 2179/2179 [00:45<00:00, 47.60it/s]


Epoch 2, loss= 4.147103744890232


100%|██████████| 2179/2179 [00:45<00:00, 47.86it/s]


Epoch 3, loss= 4.068156126826769


100%|██████████| 2179/2179 [00:46<00:00, 47.30it/s]


Epoch 4, loss= 3.8725483088167287


100%|██████████| 2179/2179 [00:45<00:00, 47.57it/s]


Epoch 5, loss= 3.824574768898847
Epoch 5
Sampling...
Aoh J n  DhMnal  a n p NO tjet ke eK sT- dvn fn H kld  vU  h oana  aph aO h e p,h Nib   n  nunw 3p wn


100%|██████████| 2179/2179 [00:45<00:00, 47.51it/s]


Epoch 6, loss= 3.65852904702721


100%|██████████| 2179/2179 [00:46<00:00, 47.08it/s]


Epoch 7, loss= 3.629333794363161


100%|██████████| 2179/2179 [00:46<00:00, 47.04it/s]


Epoch 8, loss= 3.573397373157447


100%|██████████| 2179/2179 [00:46<00:00, 47.26it/s]


Epoch 9, loss= 3.504501010366075


100%|██████████| 2179/2179 [00:45<00:00, 47.37it/s]


Epoch 10, loss= 3.520072160482735
Epoch 10
Sampling...
AZe  te tJtfa    tOth ieBehersoeep 3eejel uhiL
Dir eiI ei phn qej,Dhhbhtt  hh n oU  
vTseejeHta Kx eo


100%|██████████| 2179/2179 [00:46<00:00, 46.83it/s]


Epoch 11, loss= 3.476538759242947


 24%|██▍       | 521/2179 [00:11<00:35, 47.08it/s]


KeyboardInterrupt: 