<a href="https://colab.research.google.com/github/susannatangg/aoc/blob/master/DoodleStories%20Eric%20Copy%203/30%208%3A42pm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

# Create directory to store the dataset (optional)
!mkdir -p datasets

# Download the dataset
!wget http://cvssp.org/data/fscoco/fscoco.tar.gz -P datasets/

# Extract the tar.gz file
!tar -xzf datasets/fscoco.tar.gz -C datasets/

print("Dataset downloaded and extracted successfully!")

--2025-03-31 01:35:10--  http://cvssp.org/data/fscoco/fscoco.tar.gz
Resolving cvssp.org (cvssp.org)... 131.227.95.12
Connecting to cvssp.org (cvssp.org)|131.227.95.12|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cvssp.org/data/fscoco/fscoco.tar.gz [following]
--2025-03-31 01:35:10--  https://cvssp.org/data/fscoco/fscoco.tar.gz
Connecting to cvssp.org (cvssp.org)|131.227.95.12|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2378884718 (2.2G) [application/x-gzip]
Saving to: ‘datasets/fscoco.tar.gz’


2025-03-31 01:37:20 (17.6 MB/s) - ‘datasets/fscoco.tar.gz’ saved [2378884718/2378884718]

Dataset downloaded and extracted successfully!


In [None]:
import os
import cv2
import numpy as np
import shutil
from PIL import Image
from sklearn.model_selection import train_test_split

# Root of the extracted FS-COCO dataset; every working folder lives under it.
dataset_path = "datasets/fscoco/"

# Input sketches (sketchycoco is read instead of raster_sketches), vector
# sketches, and the two output locations used by later cells.
sketch_dir, vector_dir, output_dir, output_dir_sub = (
    os.path.join(dataset_path, folder)
    for folder in ("sketchycoco", "vector_sketches", "processed_sketches", "processed_sub")
)

os.makedirs(output_dir, exist_ok=True)

In [None]:
def preprocess_sketch(img_path, size=(224, 224), threshold=128):
    """Load a sketch as a resized, binarized, [0, 1]-normalized grayscale array.

    Args:
        img_path: Path of the image file to read.
        size: Target size passed to ``cv2.resize``; OpenCV interprets it as
            ``(width, height)``.
        threshold: Grey level used for binarization. Pixels strictly above it
            become 255 and all others 0 (``cv2.THRESH_BINARY``) before
            normalization. Defaults to 128.

    Returns:
        The image divided by 255.0 (values 0.0 / 1.0), or ``None`` when
        OpenCV cannot read ``img_path`` (a warning is printed in that case).
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)  # Read in grayscale
    if img is None:
        print(f"Warning: Unable to read image {img_path}")
        return None  # Caller skips unreadable / non-image files

    img = cv2.resize(img, size)  # Resize to fixed dimensions
    _, img = cv2.threshold(img, threshold, 255, cv2.THRESH_BINARY)  # Binarize
    return img / 255.0  # Normalize pixel values to 0.0 / 1.0

# Process sketches: resize/binarize every readable file under sketch_dir and
# save it under output_dir with the same relative sub-folder layout.
for root, _, files in os.walk(sketch_dir):
    for file in files:
        img_path = os.path.join(root, file)
        processed_img = preprocess_sketch(img_path)
        if processed_img is None:
            print(f"Skipping file: {img_path}")
            continue

        # Always save as PNG: the binarized pixels must survive encoding
        # losslessly (a .jpg source would otherwise be re-saved as lossy JPEG),
        # and the later flattening cell only picks up ".png" files.
        relative_path = os.path.relpath(img_path, sketch_dir)
        save_path = os.path.join(output_dir, os.path.splitext(relative_path)[0] + ".png")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Scale the 0.0/1.0 values back to 0/255 bytes
        cv2.imwrite(save_path, (processed_img * 255).astype('uint8'))

In [None]:
def augment_sketch(img_path, angle=15, border_value=0):
    """Return the sketch together with a mirrored and a rotated variant.

    Args:
        img_path: Path of the image file to read (loaded as grayscale).
        angle: Rotation in degrees for the rotated variant (OpenCV treats
            positive values as counter-clockwise). Defaults to 15.
        border_value: Grey level used to fill the corners the rotation
            exposes. The default 0 fills them with black; for sketches drawn on
            a white background 255 is presumably the better choice — TODO
            confirm against the data.

    Returns:
        ``[original, horizontally_flipped, rotated]``, all the size of the
        input, or ``None`` if the file could not be read.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return None

    # flipCode=1 mirrors left <-> right
    flipped = cv2.flip(img, 1)

    # Rotate about the image centre without scaling, keeping the output size
    h, w = img.shape[:2]
    rotation_matrix = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
    rotated = cv2.warpAffine(img, rotation_matrix, (w, h), borderValue=border_value)

    return [img, flipped, rotated]

# Apply augmentation: write the original, flipped and rotated version of every
# readable file under sketch_dir as <stem>_aug0/1/2.png.
for root, _, files in os.walk(sketch_dir):
    # Mirror the sub-folder layout used by the preprocessing loop so that
    # same-named sketches from different sub-folders cannot overwrite each other.
    target_dir = os.path.normpath(os.path.join(output_dir, os.path.relpath(root, sketch_dir)))
    for file in files:
        # os.walk yields names relative to `root`; joining them to sketch_dir
        # would only find files that sit directly in sketch_dir.
        img_path = os.path.join(root, file)
        augmented_images = augment_sketch(img_path)
        if not augmented_images:
            continue  # unreadable or non-image file

        os.makedirs(target_dir, exist_ok=True)
        stem = os.path.splitext(file)[0]
        for i, aug_img in enumerate(augmented_images):
            cv2.imwrite(os.path.join(target_dir, f"{stem}_aug{i}.png"), aug_img)


In [None]:
# Move the preprocessed/augmented sketches aside to "processed_sub"; the
# "processed_sketches" name is reused afterwards for the train/val/test folders.
old_dir, new_dir = (
    'datasets/fscoco/processed_sketches',
    'datasets/fscoco/processed_sub',
)
os.rename(old_dir, new_dir)  # raises if old_dir is missing (e.g. on a re-run)

In [None]:
# Split the top-level entries of processed_sub 80/10/10 into train/val/test.
# Entries may be sub-folders (the preprocessing cell mirrors sketch_dir's
# layout); a sub-folder then goes to a single split as a whole.
# os.listdir order is filesystem-dependent, so sort first: otherwise the same
# random_state can still produce a different split on another machine.
# NOTE(review): variants of one sketch that sit as separate top-level files can
# land in different splits (leakage) — consider grouping by source sketch.
sketch_files = sorted(os.listdir(output_dir_sub))
train_files, temp_files = train_test_split(sketch_files, test_size=0.2, random_state=42)
val_files, test_files = train_test_split(temp_files, test_size=0.5, random_state=42)

In [None]:
import shutil
import os

# Destination folders for the three splits (re-using the processed_sketches name)
train_dir = "datasets/fscoco/processed_sketches/train"
val_dir = "datasets/fscoco/processed_sketches/val"
test_dir = "datasets/fscoco/processed_sketches/test"

for split_dir in (train_dir, val_dir, test_dir):
    os.makedirs(split_dir, exist_ok=True)

def move_files_to_target(files, target_dir, source_dir=None):
    """Move each named entry from ``source_dir`` into ``target_dir``.

    Args:
        files: Names (not paths) of files or folders inside ``source_dir``;
            each keeps its name at the destination.
        target_dir: Destination directory (expected to exist already).
        source_dir: Directory the names are relative to. Defaults to the
            notebook-level ``output_dir_sub`` so existing calls keep working.
    """
    if source_dir is None:
        source_dir = output_dir_sub  # notebook global set in the path-config cell
    for file in files:
        shutil.move(os.path.join(source_dir, file), os.path.join(target_dir, file))

# Distribute each split's entries into its destination folder
for split_files, split_dir in (
    (train_files, train_dir),
    (val_files, val_dir),
    (test_files, test_dir),
):
    move_files_to_target(split_files, split_dir)

print("Train/Validation/Test split completed.")


Train/Validation/Test split completed.


In [None]:
# move png images, get rid of all the subfolders
import os
import shutil

processed_sketches_dir = "datasets/fscoco/processed_sketches"

dirs = ['train', 'val', 'test']


def flatten_png_dir(dir_path):
    """Lift every nested ``.png`` under ``dir_path`` up to ``dir_path`` itself,
    then delete the sub-folders that end up empty.

    A nested file whose name already exists at the top level is left in place
    and reported, because ``shutil.move`` onto an existing file path would
    replace that file. Non-PNG files are not touched, so their folders stay.
    """
    for root, _subdirs, files in os.walk(dir_path):
        if root == dir_path:
            continue  # top-level files are already where they belong
        for file in files:
            if not file.endswith(".png"):
                continue
            old_path = os.path.join(root, file)
            new_path = os.path.join(dir_path, file)
            if os.path.exists(new_path):
                print(f"Skipping {old_path} - {file} already exists in {dir_path}")
                continue
            shutil.move(old_path, new_path)

    # Bottom-up walk: children are handled before their parents, so a folder
    # that only contained (now removed) empty folders is removed as well.
    for root, subdirs, _files in os.walk(dir_path, topdown=False):
        for subdir in subdirs:
            subfolder_path = os.path.join(root, subdir)
            if not os.listdir(subfolder_path):
                os.rmdir(subfolder_path)


for dir_name in dirs:
    dir_path = os.path.join(processed_sketches_dir, dir_name)
    if os.path.isdir(dir_path):
        flatten_png_dir(dir_path)
    else:
        print(f"Directory {dir_path} does not exist.")

print("All PNG files have been moved and subfolders removed.")

All PNG files have been moved and subfolders removed.


In [None]:
# move txt files, get rid of all the subfolders
# Lift every .txt file from the immediate sub-folders of text_dir into text_dir
# itself, then delete each sub-folder once it is empty.

text_dir = "datasets/fscoco/text"

for subdir in os.listdir(text_dir):
    subdir_path = os.path.join(text_dir, subdir)
    if not os.path.isdir(subdir_path):
        continue

    for file in os.listdir(subdir_path):
        if not file.endswith(".txt"):
            continue
        file_path = os.path.join(subdir_path, file)
        new_path = os.path.join(text_dir, file)
        if os.path.exists(new_path):
            print(f"Skipping {file} - already exists in {text_dir}")
        else:
            shutil.move(file_path, new_path)

    # A skipped duplicate or a non-.txt file keeps the folder non-empty;
    # os.rmdir would raise OSError there, so keep such folders and report them.
    if os.listdir(subdir_path):
        print(f"Kept non-empty directory: {subdir_path}")
    else:
        os.rmdir(subdir_path)
        print(f"Removed empty directory: {subdir_path}")

print("All txt files have been moved, and subfolders have been deleted.")

Removed empty directory: datasets/fscoco/text/73
Removed empty directory: datasets/fscoco/text/19
Removed empty directory: datasets/fscoco/text/77
Removed empty directory: datasets/fscoco/text/41
Removed empty directory: datasets/fscoco/text/16
Removed empty directory: datasets/fscoco/text/9
Removed empty directory: datasets/fscoco/text/60
Removed empty directory: datasets/fscoco/text/92
Removed empty directory: datasets/fscoco/text/24
Removed empty directory: datasets/fscoco/text/96
Removed empty directory: datasets/fscoco/text/89
Removed empty directory: datasets/fscoco/text/2
Removed empty directory: datasets/fscoco/text/94
Removed empty directory: datasets/fscoco/text/74
Removed empty directory: datasets/fscoco/text/10
Removed empty directory: datasets/fscoco/text/79
Removed empty directory: datasets/fscoco/text/86
Removed empty directory: datasets/fscoco/text/85
Removed empty directory: datasets/fscoco/text/54
Removed empty directory: datasets/fscoco/text/39
Removed empty director

In [None]:
from collections import Counter

def build_vocab(caption_dir, min_freq=1):
    """Build a word -> index vocabulary from the ``.txt`` captions in a folder.

    Each caption file is read as UTF-8, lower-cased and split on whitespace.
    Words get indices 1..N in descending frequency; ``"<PAD>"`` is 0 and
    ``"<UNK>"`` takes the next free index after the words.

    Args:
        caption_dir: Directory whose top-level ``*.txt`` files are read.
        min_freq: Words seen fewer times than this are left out. The default
            of 1 keeps every word.

    Returns:
        dict mapping token string to integer index.
    """
    vocab_counter = Counter()

    # Sorted so that equally frequent words (which Counter.most_common orders
    # by first occurrence) get the same indices regardless of the
    # filesystem-dependent order of os.listdir.
    for caption_file in sorted(os.listdir(caption_dir)):
        if not caption_file.endswith('.txt'):
            continue
        caption_path = os.path.join(caption_dir, caption_file)
        # Explicit encoding: the platform default is not guaranteed to be UTF-8
        with open(caption_path, 'r', encoding='utf-8') as file:
            vocab_counter.update(file.read().lower().split())

    kept_words = [word for word, count in vocab_counter.most_common() if count >= min_freq]
    vocab = {word: idx for idx, word in enumerate(kept_words, start=1)}
    vocab["<PAD>"] = 0
    vocab["<UNK>"] = len(vocab)  # first index not used by the words or <PAD>

    return vocab

# Build the vocabulary from the flattened caption files
caption_dir = "datasets/fscoco/text"
vocab = build_vocab(caption_dir)
vocab_size = len(vocab)
print(f"Vocabulary size: {vocab_size}")


Vocabulary size: 2838


In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import os
import sys

# Debugging aid, always on here: synchronous CUDA kernel launches make an
# error surface at the operation that caused it (slower GPU execution).
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

def _select_device():
    """Return the CUDA device if a tiny test allocation on it succeeds, else CPU.

    On failure the problem is printed and CPU is returned; no global state
    (such as ``torch.cuda.is_available``) is modified.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    try:
        # Touch the GPU once so a broken CUDA setup fails here, not mid-epoch
        torch.zeros(1, device="cuda")
    except RuntimeError as e:
        print(f"CUDA failed: {e}. Falling back to CPU")
        return torch.device("cpu")
    print("CUDA initialized successfully")
    return torch.device("cuda")


def train_doodle_model(train_dir, val_dir, test_dir, caption_dir, vocab, num_epochs=5):
    """Train the caption model on the sketches in ``train_dir``.

    Args:
        train_dir: Folder of training sketches, passed to ``DoodleDataset``.
        val_dir: Accepted for interface compatibility; currently unused.
        test_dir: Accepted for interface compatibility; currently unused.
        caption_dir: Folder with the caption ``.txt`` files.
        vocab: Token -> index mapping; index 0 is treated as padding by the loss.
        num_epochs: Number of passes over the training set.

    Returns:
        ``(model, vocab)``: the trained ``DoodleStoryModel`` and the ``vocab``
        argument unchanged. If a ``RuntimeError`` interrupts training, it is
        printed and the partially trained model is returned.

    ``DoodleDataset`` and ``collate_fn`` are not defined here; they must come
    from elsewhere in the notebook.
    """
    print("Initializing training with CUDA safety checks")
    device = _select_device()

    # NOTE(review): vocab already contains <PAD>/<UNK>; the +4 presumably
    # reserves indices for extra special tokens added by DoodleDataset —
    # TODO confirm, otherwise the embedding has unused rows.
    VOCAB_SIZE = len(vocab) + 4
    print(f"Final vocabulary size: {VOCAB_SIZE}")

    class SafeDoodleDataset(DoodleDataset):
        """DoodleDataset whose samples never raise and stay inside the vocab."""

        def __getitem__(self, idx):
            try:
                img, caption = super().__getitem__(idx)
                # Keep indices inside the embedding table
                return img, torch.clamp(caption, 0, VOCAB_SIZE - 1)
            except Exception as e:
                # NOTE(review): substituting random data hides loader bugs and
                # trains on noise; consider failing or skipping the sample.
                print(f"Error in sample {idx}: {e}")
                return torch.rand(3, 224, 224), torch.tensor([1, 2], dtype=torch.long)

    # Replicate the grayscale sketch to 3 channels and apply ImageNet statistics
    transform = transforms.Compose([
        transforms.Grayscale(3),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    class DoodleStoryModel(nn.Module):
        """Embedding -> LSTM -> linear model producing one logit per token."""

        def __init__(self, vocab_size, embed_size=256, hidden_size=512, num_layers=2):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_size)
            self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
            self.fc = nn.Linear(hidden_size, vocab_size)  # output size matches vocab

        def forward(self, images, captions):
            # NOTE(review): `images` is ignored, so captions are not conditioned
            # on the doodle at all — an image encoder is still TODO.
            embedded = self.embed(captions)
            output, _ = self.lstm(embedded)
            return self.fc(output)

    train_dataset = SafeDoodleDataset(train_dir, caption_dir, transform,
                                      lambda x: x.split(), vocab)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                              collate_fn=collate_fn, pin_memory=False)

    model = DoodleStoryModel(vocab_size=VOCAB_SIZE).to(device)
    print(f"Model structure verified - Output size: {VOCAB_SIZE}")

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Index 0 is <PAD>, so padded positions do not contribute to the loss
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    try:
        for epoch in range(num_epochs):
            model.train()
            for batch_idx, (images, captions) in enumerate(train_loader):
                captions = torch.clamp(captions, 0, VOCAB_SIZE - 1)
                images = images.to(device)
                captions = captions.to(device)

                # Teacher forcing: predict token t+1 from the tokens up to t
                outputs = model(images, captions[:, :-1])
                assert outputs.shape[-1] == VOCAB_SIZE, "Model output mismatch!"

                loss = criterion(outputs.reshape(-1, VOCAB_SIZE),
                                 captions[:, 1:].reshape(-1))

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()

                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

    except RuntimeError as e:
        # Training stops here; the partially trained model is returned.
        print(f"Critical error: {e}")
        if 'CUDA' in str(e):
            print("Moving model to CPU (training is not resumed)")
            try:
                model = model.to(torch.device("cpu"))
            except RuntimeError as move_err:
                # After a device-side error the CUDA context may be unusable
                print(f"Could not move model to CPU: {move_err}")

    return model, vocab

In [None]:
# Train on the split sketches; returns the model and the (unchanged) vocabulary
model, vocab = train_doodle_model(
    train_dir,
    val_dir,
    test_dir,
    caption_dir,
    vocab,
)

Initializing training with CUDA safety checks
CUDA initialized successfully
Final vocabulary size: 2842
Model structure verified - Output size: 2842
Epoch 1, Batch 0, Loss: 7.9565
Epoch 1, Batch 1, Loss: 7.9329
Epoch 1, Batch 2, Loss: 7.8957
Epoch 1, Batch 3, Loss: 7.8528
Epoch 1, Batch 4, Loss: 7.7320
Epoch 1, Batch 5, Loss: 7.5625
Epoch 1, Batch 6, Loss: 7.2219
Epoch 1, Batch 7, Loss: 6.8427
Epoch 1, Batch 8, Loss: 6.7060
Epoch 1, Batch 9, Loss: 6.6099
Epoch 1, Batch 10, Loss: 6.4728
Epoch 1, Batch 11, Loss: 6.2870
Epoch 1, Batch 12, Loss: 6.5887
Epoch 1, Batch 13, Loss: 5.7408
Epoch 1, Batch 14, Loss: 5.8980
Epoch 1, Batch 15, Loss: 5.7426
Epoch 1, Batch 16, Loss: 5.9880
Epoch 1, Batch 17, Loss: 5.7926
Epoch 1, Batch 18, Loss: 5.6713
Epoch 1, Batch 19, Loss: 5.8017
Epoch 1, Batch 20, Loss: 5.5397
Epoch 1, Batch 21, Loss: 5.8246
Epoch 1, Batch 22, Loss: 5.3548
Epoch 1, Batch 23, Loss: 5.8261
Epoch 1, Batch 24, Loss: 5.4494
Epoch 1, Batch 25, Loss: 5.7166
Epoch 1, Batch 26, Loss: 5.38