In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [None]:
from codebook import Codebook

# Build the project's Codebook from the transcription text, requesting 768-dimensional
# embeddings. Later cells read its `sentences`, `char_to_idx`, `idx_to_char` and
# `embeddings` attributes, call it on index tensors, and move it with `.to(device)`.
codebook = Codebook("../data/transcription.txt", embedding_dim=768)

In [None]:
# Dataset and DataLoader
class TextDataset(Dataset):
    """Serve sentences as lists of character indices for language-model training.

    Args:
        sentences: Indexable collection of strings.
        char_to_idx: Mapping from character to integer index. It must contain
            the key ``"s"``, which is prepended to every encoded sentence.
    """

    def __init__(self, sentences, char_to_idx):
        self.sentences = sentences
        self.char_to_idx = char_to_idx

    def __len__(self):
        """Number of sentences available."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Encode sentence ``idx``; characters missing from the vocabulary are dropped."""
        text = self.sentences[idx]
        vocab = self.char_to_idx

        body = []
        for ch in text:
            if ch in vocab:
                body.append(vocab[ch])

        # Every sequence starts with the index of the "s" marker.
        return [vocab["s"], *body]

# Collate function: truncates every sample to the batch's shortest length and builds next-token targets
def collate_fn(batch, pad_idx=0):
    """Turn a list of index sequences into ``(inputs, targets)`` tensors.

    Every sample is truncated to the length of the shortest sample in the
    batch, so no padding of the inputs is needed. Targets are the inputs
    shifted one step to the left (next-token prediction).

    For a sample that was truncated, the last target is the real next token
    from the original sequence. Only a sample that ends exactly at the
    truncation length receives ``pad_idx`` as its final target, because no
    further token exists for it.

    Args:
        batch: Non-empty list of integer sequences (lists of indices).
        pad_idx: Placeholder target used when a sample has no next token.
            Defaults to 0.

    Returns:
        Tuple ``(inputs, targets)`` of ``torch.long`` tensors, both shaped
        ``(len(batch), min_len)``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    min_len = min(len(sample) for sample in batch)  # shortest sequence in the batch

    input_rows = []
    target_rows = []
    for sample in batch:
        input_rows.append(list(sample[:min_len]))

        # Next-token targets: positions 1..min_len of the *untruncated* sample.
        shifted = list(sample[1:min_len + 1])
        if len(shifted) < min_len:
            # The sample ends here, so there is no real next token to predict.
            shifted.append(pad_idx)
        target_rows.append(shifted)

    inputs = torch.tensor(input_rows, dtype=torch.long)
    targets = torch.tensor(target_rows, dtype=torch.long)
    return inputs, targets
  

# dataset = TextDataset(codebook.sentences, codebook.char_to_idx)
# dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

# for i in dataloader:
#     i = i
#     break


In [None]:
class CausalConv1d(nn.Module):
    """1-D convolution whose output at step ``t`` depends only on inputs at steps ``<= t``.

    The wrapped ``nn.Conv1d`` applies no padding of its own; ``forward`` pads
    the time axis on the left by ``(kernel_size - 1) * dilation`` so the
    output keeps the input's length.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Width of the convolution kernel.
        dilation: Spacing between kernel taps.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=0,  # padding is applied manually in forward()
        )

    def forward(self, x):
        """Left-pad the last axis of ``x`` and convolve."""
        left_context = self.dilation * (self.kernel_size - 1)
        padded = F.pad(x, (left_context, 0))
        return self.conv(padded)

class CausalConvBlock(nn.Module):
    """Residual block of stacked dilated causal convolutions.

    Each stage is ``CausalConv1d -> BatchNorm1d -> ReLU -> Dropout``; one
    stage is built per entry of ``dilations``. The block output is the stage
    stack's output added to a shortcut of the input.

    When ``in_channels == out_channels`` the shortcut is the input itself.
    Otherwise a 1x1 convolution projects the input to ``out_channels`` so the
    residual sum is well defined (without it the addition fails on mismatched
    channel counts).

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by every stage and by the block.
        kernel_size: Kernel width shared by all causal convolutions.
        dilations: Dilation of each successive stage. Defaults to ``(1, 2, 4)``.
        dropout: Dropout probability applied after each activation.
            Defaults to 0.2.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilations=(1, 2, 4), dropout=0.2):
        super().__init__()

        stages = []
        stage_in = in_channels
        for dilation in dilations:
            stages += [
                CausalConv1d(stage_in, out_channels, kernel_size=kernel_size, dilation=dilation),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.Dropout(dropout),  # dropout after activation
            ]
            stage_in = out_channels  # only the first stage sees in_channels
        self.net = nn.Sequential(*stages)

        # No extra module is registered when the channel counts already match,
        # so parameter names and counts are unchanged in that case.
        if in_channels == out_channels:
            self.shortcut = None
        else:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``shortcut(x) + net(x)``."""
        residual = x if self.shortcut is None else self.shortcut(x)
        return residual + self.net(x)
    
    
class CausalCNN(nn.Module):
    """Character-level model: 1x1 input projection, causal conv blocks, 1x1 output head.

    ``forward`` takes features shaped ``(batch, input_dim, time)`` and returns
    logits shaped ``(batch, vocab_size, time)``.

    Args:
        input_dim: Channel count of the input features.
        hidden_dim: Channel count used inside every ``CausalConvBlock``.
        num_layers: Number of ``CausalConvBlock`` modules to stack.
        kernel_size: Kernel width passed to every block.
        vocab_size: Number of output classes per time step.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, kernel_size, vocab_size):
        super().__init__()

        self.input_layer = nn.Conv1d(input_dim, hidden_dim, 1)

        layers = [
            CausalConvBlock(hidden_dim, hidden_dim, kernel_size)
            for _ in range(num_layers)
        ]
        self.network = nn.Sequential(*layers)
        self.output_layer = nn.Conv1d(hidden_dim, vocab_size, 1)

    def forward(self, x):
        """Project the input, run the block stack, and return per-step logits."""
        x = self.input_layer(x)
        x = self.network(x)
        return self.output_layer(x)

    def generate_output(self, input_str, codebook, char_to_idx, idx_to_char, max_length=100, end_token="s"):
        """
        Greedily extend ``input_str`` one character at a time.

        The whole running sequence is re-encoded and re-evaluated at every
        step; the arg-max class at the last time step becomes the next
        character. At least one character is always generated, even if
        ``input_str`` is already ``max_length`` long.

        The model is switched to evaluation mode and left in that mode.

        Args:
        - input_str: Non-empty prompt; every character must be a key of ``char_to_idx``.
        - codebook: Callable mapping a ``(1, time)`` index tensor to embeddings
          shaped ``(1, time, input_dim)``; expected to live on the model's device.
        - char_to_idx: Dictionary mapping characters to indices.
        - idx_to_char: Dictionary mapping indices to characters.
        - max_length: Generation stops once the string reaches this length.
        - end_token: Character that terminates generation when produced.

        Returns:
        - Generated string (the prompt followed by the generated characters).

        Raises:
        - ValueError: If ``input_str`` is empty.
        - KeyError: If ``input_str`` contains a character missing from ``char_to_idx``.
        """
        if not input_str:
            raise ValueError("input_str must contain at least one character")

        # Use the device the model lives on instead of relying on a global.
        device = next(self.parameters()).device

        input_indices = [char_to_idx[char] for char in input_str]
        generated_str = input_str  # the output starts with the prompt

        self.eval()
        with torch.no_grad():
            while True:
                input_tensor = torch.tensor(
                    input_indices, dtype=torch.long, device=device
                ).unsqueeze(0)  # add batch dimension

                x = codebook(input_tensor).transpose(1, 2)
                logits = self(x)

                # Index the last time step explicitly; squeezing first would
                # collapse a length-1 sequence to a scalar.
                next_char_idx = int(logits[0, :, -1].argmax())
                next_char = idx_to_char[next_char_idx]

                generated_str += next_char
                input_indices.append(next_char_idx)

                if next_char == end_token or len(generated_str) >= max_length:
                    break

        return generated_str

    
    
# Instantiate the model; its input width is the second dimension of `codebook.embeddings.weight`.
embed_dim = codebook.embeddings.weight.shape[1]
model = CausalCNN(
    input_dim=embed_dim,
    hidden_dim=256,
    num_layers=5,
    kernel_size=9,
    vocab_size=len(codebook.char_to_idx),
)

# Function to calculate total parameters
def calculate_params(model):
    """Return the number of elements across ``model.parameters()``, in millions."""
    n_params = 0
    for tensor in model.parameters():
        n_params += tensor.numel()
    return n_params / 1e6

print(calculate_params(model))

# Plain cross-entropy over the vocabulary. No `ignore_index` is set, so the 0
# placeholder that `collate_fn` puts in target rows is scored like any other class.
criterion = nn.CrossEntropyLoss()

# Prefer GPU 2, but fall back to GPU 0 on machines that have CUDA with fewer
# than three devices (asking for a non-existent "cuda:2" fails at `.to(device)`).
PREFERRED_GPU = 2
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if PREFERRED_GPU < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

model = model.to(device)
codebook = codebook.to(device)
criterion = criterion.to(device)

In [None]:
# Where the trained network's weights are written after every epoch.
MODEL_CHECKPOINT_PATH = "causal_cnn.pt"

dataset = TextDataset(codebook.sentences, codebook.char_to_idx)
dataloader = DataLoader(dataset, batch_size=512, shuffle=False, collate_fn=collate_fn)

# Only the CNN's parameters are handed to the optimizer.
# NOTE(review): the codebook is not optimised here, so its embeddings are not
# updated by this loop — confirm that is intended.
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Training loop
num_epochs = 10000
num_batches = len(dataloader)
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for iteration, (inputs, targets) in enumerate(dataloader):

        inputs = inputs.to(device)
        targets = targets.to(device)
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass: the codebook output is transposed so channels come
        # before time, as Conv1d expects.
        x = codebook(inputs)
        x = x.transpose(1, 2)
        outputs = model(x)

        # Compute the loss
        loss = criterion(outputs, targets)

        # Backward pass
        loss.backward()

        # Update the parameters
        optimizer.step()

        torch.cuda.empty_cache()

        # Read the scalar once and reuse it for both the running total and the log line.
        batch_loss = loss.item()
        running_loss += batch_loss
        if iteration % 10 == 0:
            print(f"Batch {iteration+1}/{num_batches}, Loss: {batch_loss:.4f}")

    # Guard against an empty dataloader so the average never divides by zero.
    avg_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    torch.save(codebook, "codebook.pt")
    # Persist the network being trained as well; previously only the codebook
    # was written, so the trained weights were lost if the kernel died.
    torch.save(model.state_dict(), MODEL_CHECKPOINT_PATH)

print("Training finished!")

In [None]:
# Write the codebook to disk once more, then sample a continuation of a fixed prompt.
torch.save(codebook, "codebook.pt")

model.eval()  # evaluation mode for sampling
x = "sCHAPTER ONE MISSUS RAC"
generated = model.generate_output(
    x,
    codebook,
    codebook.char_to_idx,
    codebook.idx_to_char,
    max_length=1000,
)
# Dashes make leading/trailing whitespace in the result visible.
f"--{generated}--"

In [None]:
# Show the first entry of the codebook's sentence list as the cell output.
codebook.sentences[0]