In [1]:
%pip install gensim
#%pip install torch
%pip install Pillow
%pip install requests
%pip install open_clip_torch
#%pip install torchvision
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

Collecting gensim
  Using cached gensim-4.3.3-cp310-cp310-win_amd64.whl.metadata (8.2 kB)
Collecting numpy<2.0,>=1.18.5 (from gensim)
  Using cached numpy-1.26.4-cp310-cp310-win_amd64.whl.metadata (61 kB)
Collecting scipy<1.14.0,>=1.7.0 (from gensim)
  Using cached scipy-1.13.1-cp310-cp310-win_amd64.whl.metadata (60 kB)
Collecting smart-open>=1.8.1 (from gensim)
  Using cached smart_open-7.1.0-py3-none-any.whl.metadata (24 kB)
Collecting wrapt (from smart-open>=1.8.1->gensim)
  Using cached wrapt-1.17.2-cp310-cp310-win_amd64.whl.metadata (6.5 kB)
Using cached gensim-4.3.3-cp310-cp310-win_amd64.whl (24.0 MB)
Using cached numpy-1.26.4-cp310-cp310-win_amd64.whl (15.8 MB)
Using cached scipy-1.13.1-cp310-cp310-win_amd64.whl (46.2 MB)
Using cached smart_open-7.1.0-py3-none-any.whl (61 kB)
Using cached wrapt-1.17.2-cp310-cp310-win_amd64.whl (38 kB)
Installing collected packages: wrapt, numpy, smart-open, scipy, gensim
Successfully installed gensim-4.3.3 numpy-1.26.4 scipy-1.13.1 smart-open-7.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gensim.downloader as api # pip install gensim
import open_clip
import os
from txt2png import txt_to_png
if torch.backends.mps.is_available():
    from open_clip_test_mps import clip_loss, device, preprocess
elif torch.cuda.is_available():
    from open_clip_test_cuda import clip_loss, device, preprocess
    print("using cuda")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
using cuda


In [5]:
# Shared configuration: word-vector width, ASCII grid geometry and the
# character vocabulary used by the generator, the dataset and the sampler.
embedding_dim = 25            # passed to the generator as its input size
x_dim, y_dim = 25, 8          # characters per row, number of rows
output_size = x_dim * y_dim   # total number of character cells in one picture

# The position of a character in this string is its class index; index 0 is
# the space character, which is also what padding and unknown characters use.
ascii_chars = " .,:;+*#@$%&0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-_=|/\\()[]{}"
num_chars = len(ascii_chars)

# Forward and reverse lookups between characters and class indices.
char_to_index = dict(zip(ascii_chars, range(num_chars)))
index_to_char = dict(enumerate(ascii_chars))

class ASCIIArtGenerator(nn.Module):
    """Fully connected network mapping a word vector to per-cell character logits.

    The network is ``input_size -> 512 -> 256 -> output_size * num_chars`` with
    a ReLU after each of the first two layers. The final layer's output is
    reshaped so that every character cell gets its own row of ``num_chars``
    logits.
    """

    def __init__(self, input_size, output_size, num_chars):
        """
        input_size: length of the input vector
        output_size: number of character cells to predict
        num_chars: number of character classes per cell
        """
        super().__init__()
        self.output_size = output_size
        self.num_chars = num_chars
        # Sub-module names and creation order are part of the checkpoint
        # format (state_dict keys) and of the seeded initialisation order.
        self.fc1 = nn.Linear(input_size, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_size * num_chars)

    def forward(self, x):
        """Return logits shaped ``(-1, output_size, num_chars)`` for input ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        logits = self.fc3(hidden)
        return logits.view(-1, self.output_size, self.num_chars)

def generate_ascii_art(model, embedding, temperature=1.0):
    """Sample one ASCII picture from ``model`` for a single word embedding.

    Args:
        model: Network whose output for a batch of one embedding holds one row
            of character logits per grid cell (last dimension = classes).
        embedding: 1-D tensor for one word; it is moved to the model's device
            and given a batch dimension before the forward pass.
        temperature: Divisor applied to the logits before the softmax.
            ``0`` selects the highest-scoring character for every cell
            (greedy decoding) instead of dividing by zero.

    Returns:
        A string of ``y_dim`` lines, each ``x_dim`` characters long and each
        terminated by a newline.

    The model is switched to eval mode for the forward pass and restored to
    whatever mode it was in before the call, so sampling from inside a
    training loop does not leave the model in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device  # device the model lives on
        with torch.no_grad():
            logits = model(embedding.to(device).unsqueeze(0))
            # One row per character cell; the class count comes from the model
            # output itself rather than from a module-level constant.
            cell_logits = logits.view(-1, logits.size(-1))
            if temperature == 0:
                indices = cell_logits.argmax(dim=-1)
            else:
                probs = F.softmax(cell_logits / temperature, dim=-1)
                indices = torch.multinomial(probs, 1).squeeze(-1)
        # Single transfer to Python ints instead of one .item() call per cell.
        flat = indices.tolist()
    finally:
        model.train(was_training)

    rows = (
        "".join(index_to_char[flat[row * x_dim + col]] for col in range(x_dim))
        for row in range(y_dim)
    )
    return "".join(line + "\n" for line in rows)

In [8]:
# Create a class to handle the dataset
class ASCIIArtDataset(torch.utils.data.Dataset):
    """Dataset pairing a label's word embedding with its ASCII art as class indices."""

    def __init__(self, data_pairs, word_model):
        """
        data_pairs: List of tuples (label_text, ascii_art)
                    label_text: the label text like "cat", "dog"
                    ascii_art: the target ASCII art as a string
        word_model: Word embedding model, indexed by label_text
        """
        self.data_pairs = data_pairs
        self.word_model = word_model

    def __len__(self):
        """Number of (label, art) pairs."""
        return len(self.data_pairs)

    def __getitem__(self, idx):
        """Return ``(word_embedding, target_tensor, label_text)`` for sample ``idx``.

        ``word_embedding`` is a float32 tensor looked up from ``word_model``;
        a label missing from the model propagates the lookup's error.
        ``target_tensor`` is a long tensor of ``output_size`` character
        indices laid out row by row.
        """
        label_text, ascii_art = self.data_pairs[idx]

        # Get word embedding
        word_embedding = torch.tensor(self.word_model[label_text], dtype=torch.float32)

        # Convert ASCII art to target indices one row at a time. Each row is
        # padded or cropped to x_dim so that target cell (row, col) lines up
        # with the grid that generate_ascii_art renders; simply dropping the
        # newlines would shift every row after a short or long line.
        target_indices = []
        for row in ascii_art.split('\n')[:y_dim]:
            # Default to 0 if char not found
            row_indices = [char_to_index.get(char, 0) for char in row[:x_dim]]
            row_indices += [0] * (x_dim - len(row_indices))
            target_indices.extend(row_indices)

        # Pad (missing rows) or truncate to output_size
        if len(target_indices) < output_size:
            target_indices += [0] * (output_size - len(target_indices))
        else:
            target_indices = target_indices[:output_size]

        target_tensor = torch.tensor(target_indices, dtype=torch.long)

        return word_embedding, target_tensor, label_text

# Function to train the model
def train_model(model, dataset, candidate_labels, num_epochs, batch_size=1, learning_rate=0.001):
    """Train ``model`` on ``dataset`` with a weighted character + CLIP loss.

    Args:
        model: Generator whose forward pass returns logits that can be viewed
            as ``(-1, num_chars)``.
        dataset: Dataset yielding ``(word_embedding, target_indices, label_text)``.
        candidate_labels: List of label strings handed to ``clip_loss``; every
            label produced by ``dataset`` must appear in it.
        num_epochs: Number of passes over the dataset.
        batch_size: Samples per optimisation step. Only the first sample of
            each batch is rendered and scored by ``clip_loss``.
        learning_rate: Learning rate for the Adam optimiser.

    Returns:
        The same ``model`` instance, trained in place.

    Raises:
        ValueError: If ``dataset`` is empty, or a batch label is not in
            ``candidate_labels``.

    Side effects: loads ``bsf_weights.pth`` if it exists, and overwrites it
    whenever an epoch after the warm-up improves on the best average loss seen
    since the warm-up ended.
    """
    weights_path = "bsf_weights.pth"
    warmup_epochs = 10                    # no checkpoint during the first 10 epochs
    char_weight, clip_weight = 0.8, 0.2   # mixing weights of the two losses

    # Fail early: an empty dataset would otherwise end in a division by zero
    # when the epoch average is computed.
    if len(dataset) == 0:
        raise ValueError("train_model requires a non-empty dataset")

    # Load best so far weights. map_location lets a checkpoint written on one
    # device type be restored on another.
    try:
        model.load_state_dict(torch.load(weights_path, map_location=device))
    except FileNotFoundError:
        print("No weights to load")
    model.to(device)

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Define loss function for character prediction
    criterion = nn.CrossEntropyLoss()

    # Only the image transform of this checkpoint is used below; the CLIP
    # model object itself is discarded right away because clip_loss does not
    # receive it.
    # NOTE(review): this shadows the `preprocess` imported next to clip_loss;
    # confirm which transform clip_loss actually expects.
    _, preprocess = open_clip.create_model_from_pretrained('hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K')

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    best_loss = float("inf")
    for epoch in range(num_epochs):
        total_loss = 0

        for word_emb, target_indices, label_text in dataloader:
            # generate_ascii_art() below calls model.eval(); make sure every
            # optimisation step runs in training mode regardless.
            model.train()

            word_emb = word_emb.to(device)
            target_indices = target_indices.to(device)

            # Forward pass
            outputs = model(word_emb)

            # Reshape outputs for loss calculation
            outputs_flat = outputs.view(-1, num_chars)
            targets_flat = target_indices.view(-1)

            # Calculate character prediction loss
            char_loss = criterion(outputs_flat, targets_flat)

            # Generate ASCII art and convert to image
            ascii_art = generate_ascii_art(model, word_emb[0].cpu())
            ascii_png = txt_to_png(ascii_art)

            # Calculate semantic alignment loss using CLIP
            image_tensor = preprocess(ascii_png).unsqueeze(0).to(device)

            # Get the label for CLIP loss
            true_label_index = candidate_labels.index(label_text[0])

            # Calculate CLIP loss
            # NOTE(review): the picture is sampled under torch.no_grad() and
            # passed on as a rendered string, so this term has no gradient
            # path back to `model`; it changes the reported loss but not the
            # generator's gradients.
            clip_loss_val = clip_loss(image_tensor, candidate_labels, true_label_index)
            #TODO: consider a dynamic weights option where is changes after some epochs
            # Combine losses - you can adjust the weights
            loss = char_loss * char_weight + clip_loss_val * clip_weight

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Print epoch stats
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

        # Save the weights if better. The best loss is tracked only over
        # epochs that are allowed to save; otherwise a minimum reached during
        # warm-up would block every later checkpoint.
        if epoch >= warmup_epochs and avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), weights_path)
            print("saved best model")

        # Print example generation (the last sample drawn in this epoch)
        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Example generation for {candidate_labels[true_label_index]}:")
            print(ascii_art)

    return model

# Example of how to create and use the dataset
def prepare_training_data(txt_dir=None, allowed_labels=None):
    """Collect ``(label, ascii_art)`` training pairs from a directory of text files.

    A file contributes a sample when its name ends in ``.txt`` and the part of
    the name before the first underscore is one of the allowed labels.

    Args:
        txt_dir: Directory to scan. Defaults to ``../Data/Text`` relative to
            the working directory.
        allowed_labels: Iterable of labels to keep. Defaults to bird, horse,
            frog, fish, dolphin, dog and cat.

    Returns:
        ``(data_pairs, labels)``: ``data_pairs`` is a list of
        ``(label, ascii_art_text)`` tuples in file-name order, and ``labels``
        is the sorted list of distinct labels found.

    File names and labels are sorted because neither ``os.listdir`` nor
    ``set`` guarantees an order, and the label order is what the caller uses
    as the candidate list.
    """
    if txt_dir is None:
        txt_dir = os.path.join("..", "Data", "Text")
    if allowed_labels is None:
        allowed_labels = ("bird", "horse", "frog", "fish", "dolphin", "dog", "cat")
    allowed = set(allowed_labels)

    data_pairs = []
    for txt_file in sorted(os.listdir(txt_dir)):
        if not txt_file.endswith(".txt"):
            continue
        label = txt_file.split('_')[0]
        if label not in allowed:
            continue
        with open(os.path.join(txt_dir, txt_file), 'r', encoding='utf-8') as file:
            data_pairs.append((label, file.read()))

    labels = sorted({label for label, _ in data_pairs})
    return data_pairs, labels



The next cell loads the word embeddings, builds the dataset, and runs the training loop.

In [None]:
# Main execution
# Load word embeddings (`api` is gensim.downloader from the import cell)
word_model_name = "glove-twitter-25"
word_model = api.load(word_model_name)
# Take the input width from the loaded vectors so that changing
# word_model_name cannot leave the generator with a mismatched input size.
embedding_dim = word_model.vector_size

# Grid dimensions (x_dim, y_dim, output_size) and the character set
# (ascii_chars, num_chars, char_to_index, index_to_char) come from the cell
# that defines ASCIIArtGenerator. generate_ascii_art and ASCIIArtDataset read
# those globals at call time, so they are deliberately not re-declared here.

# Initialize model
model = ASCIIArtGenerator(embedding_dim, output_size, num_chars)
print("model initialized")

# Prepare dataset
data_pairs, label_list = prepare_training_data()
dataset = ASCIIArtDataset(data_pairs, word_model)
print("dataset prepared")

# Train model
trained_model = train_model(model, dataset, label_list, num_epochs=50, learning_rate=0.001)


model initialized
dataset prepared
No weights to load
Epoch 1/50, Loss: 1.4535
Example generation for cat:
 \  _   ({ (__\\_\|I   j 
   |(  . _       )  /  \/
] _/ -.# - _ ;        _/-
  _      ( x  __         
     A    /      +       
       0      F          
                         
            s            

Epoch 2/50, Loss: 1.0061
Epoch 3/50, Loss: 0.9715
Epoch 4/50, Loss: 0.9527
Epoch 5/50, Loss: 0.9593
Example generation for horse:
, ,_/ (   _    (-(__  ). 
 .  -      _  \ @)_*   _ 
\/ ,Y        _  ) |_  |} 
   { _       , +         
          -       -      
                         
                         
                         

Epoch 6/50, Loss: 0.9494
Epoch 7/50, Loss: 0.9349
Epoch 8/50, Loss: 0.9300
Epoch 9/50, Loss: 0.9170
Epoch 10/50, Loss: 0.9004
Example generation for cat:
  \ //     / \(\/,  _  { 
 . ||      | _          \
  /  \  \/_ -  _     _/  
\ ;      (          _ )  
                 _    _  
              ,          
                         
         

In [None]:

# Sample one picture per probe word from the trained generator.
test_words = ["cat", "dog", "wolf", "eagle"]
for word in test_words:
    # Words missing from the embedding vocabulary are skipped without a message.
    if word not in word_model:
        continue
    embedding = torch.tensor(word_model[word])
    ascii_art = generate_ascii_art(trained_model, embedding, temperature=0.8)
    # Header line, the picture, then one blank separator line.
    print(f"ASCII art for '{word}':", ascii_art, "", sep="\n")