In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader

### Create dataset

In [9]:
data_path = 'ascii_data/cleaned_ascii_art_data.json'
# the path leads to a json file that was saved from a pandas series
df = pd.read_json(data_path, typ='series')

In [14]:
print(df[0])

                                                            
                                                            
                                                            
                                                            
                                                            
                                                            
                    ,__    _,            ___                
                   '.`\ /`|     _.-"```   `'.               
                    ; |  /   .'             `}              
                    _\|\/_.-'                 }             
               _.-"a                 {        }             
            .-`  __    /._          {         }\            
           '--"`  `""`   `\   ;    {         } \            
                          |   } __ _\       }\  \           
                         |  /;`   / :.   }`  \  \           
                        | | | .-' /  / /     '. '._         
             jgs    .'__

In [15]:
class AsciiArtDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [16]:
ascii_data = AsciiArtDataset(df)
dataloader = DataLoader(ascii_data, batch_size=4, shuffle=True)

In [19]:
# Get a batch of data
dataiter = next(iter(dataloader))
print(dataiter[0])

                                                            
                                                            
                                                            
                                                            
                     ___________________                    
                     | _______________ |                    
                     | |XXXXXXXXXXXXX| |                    
                     | |XXXXXXXXXXXXX| |                    
                     | |XXXXXXXXXXXXX| |                    
                     | |XXXXXXXXXXXXX| |                    
                     | |XXXXXXXXXXXXX| |                    
                     |_________________|                    
                           _[_______]_                      
                     ___[___________]___                    
                  |         [_____] []|__                   
                 |         [_____] []|  \__                 
             L__________

In [22]:
UNIQUE_CHARS = set()
for i in range(len(df)):
    for char in df[i]:
        UNIQUE_CHARS.add(char)

UNIQUE_CHARS = list(UNIQUE_CHARS)
len(UNIQUE_CHARS)

111

In [48]:
class StableDiffusionModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dropout):
        super(StableDiffusionModel, self).__init__()
        self.encoder = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers=num_layers, num_decoder_layers=num_layers, dropout=dropout)
        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt):
        src = self.encoder(src)
        tgt = self.encoder(tgt)
        output = self.transformer(src, tgt)
        output = self.decoder(output)
        return output

In [52]:
def train(model, dataloader, optimizer, criterion, device, UNIQUE_CHARS):
    model.train()
    total_loss = 0

    for batch in dataloader:
        # Convert characters to indices
        batch_indices = [[UNIQUE_CHARS.index(char) for char in sample] for sample in batch]
        batch_indices = torch.tensor(batch_indices).to(device)

        optimizer.zero_grad()

        # Obscure the input image
        mask = torch.rand(batch_indices.shape, device=device) < 0.15
        random_indices = torch.randint(len(UNIQUE_CHARS), size=batch_indices.shape, device=device)
        obscured_batch = torch.where(mask, random_indices, batch_indices)

        # Forward pass
        output = model(obscured_batch, batch_indices)

        # Calculate loss
        loss = criterion(output.view(-1, len(UNIQUE_CHARS)), batch_indices.view(-1))
        total_loss += loss.item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

    return total_loss / len(dataloader)

In [54]:
# Hyperparameters
vocab_size = len(UNIQUE_CHARS)
d_model = 256
nhead = 8
num_layers = 6
dropout = 0.1
batch_size = 32
num_epochs = 50
learning_rate = 0.0001

# Create the dataset and dataloader
dataset = AsciiArtDataset(df)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Create the model, optimizer, and loss function
device = torch.device("mps")
model = StableDiffusionModel(vocab_size, d_model, nhead, num_layers, dropout).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(num_epochs):
    train_loss = train(model, dataloader, optimizer, criterion, device, UNIQUE_CHARS)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}")



Epoch [1/50], Train Loss: 0.6602


KeyboardInterrupt: 