# Chainable Markov Chain Model

Trains a model that learns a chainable composition operation in latent space for Markov chain prediction.

## Setup: Output Directory, Repo Clone, and Dependencies

In [None]:
# Where checkpoints and generated samples are written. This is the Colab VM's
# local disk, so anything stored here only lasts for the current runtime.
#
# To keep outputs across sessions, mount Google Drive instead and point
# OUTPUT_DIR at a folder inside it:
#   from google.colab import drive
#   drive.mount('/content/drive')
import os

OUTPUT_DIR = '/content/out'
os.makedirs(OUTPUT_DIR, exist_ok=True)  # no-op when the directory already exists
print('Outputs will be saved to: {}'.format(OUTPUT_DIR))

In [None]:
# Fetch the project sources and make the clone the working directory, so that
# `from model import ...` and the relative path 'data/shakespeare.txt' used in
# later cells resolve.
# NOTE(review): re-running this cell after the %cd clones a second, nested copy
# (markov-learned/markov-learned) and changes into it; restart the runtime or
# skip this cell when re-running.
!git clone https://github.com/sughodke/markov-learned.git
%cd markov-learned

In [None]:
# Install the Weights & Biases client used for experiment logging below
# (-q keeps pip's output quiet).
!pip install wandb -q

In [None]:
import torch
from torch.utils.data import DataLoader

# Project-local building blocks from the cloned repo's model module.
from model import (
    ChainableMarkovModel,
    CharVocab,
    NgramDataset,
    collate_ngrams,
    generate,
    train,
)

# Prefer the GPU whenever the runtime has one attached.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## Initialize Weights & Biases

In [None]:
import wandb
from google.colab import userdata

# Authenticate using the API key stored in Colab's Secrets panel.
wandb.login(key=userdata.get('WANDB_API_KEY'))

# Training hyperparameters; the same dict is attached to the W&B run below.
config = dict(
    d_latent=128,
    d_hidden=512,
    dropout=0.1,
    batch_size=256,     # raised to speed up training
    epochs=10,          # cut from 50 to shorten runs
    lr=3e-4,
    weight_decay=0.01,
    gradient_clip=1.0,
)

wandb.init(
    name='shakespeare-run-v3',
    project='chainable-markov',
    config=config,
)

## Load Data

In [None]:
# Read the whole training corpus into a single string.
with open('data/shakespeare.txt', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
print(f"Corpus size: {len(text):,} characters")

# One vocabulary built from the full corpus, shared by both splits so they
# use the same character-to-index mapping.
vocab = CharVocab(text)
print(f"Vocabulary size: {vocab.vocab_size}")

# Contiguous 90/10 split: the head of the corpus trains, the tail validates.
split_idx = int(len(text) * 0.9)
train_text, val_text = text[:split_idx], text[split_idx:]

train_dataset = NgramDataset(train_text, vocab)
val_dataset = NgramDataset(val_text, vocab)
for split_name, split_dataset in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"{split_name} samples: {len(split_dataset):,}")

## Create Data Loaders and Model

In [None]:
# Both splits are batched identically; only the training data is shuffled.
loader_options = {
    'batch_size': config['batch_size'],
    'collate_fn': collate_ngrams,
    'num_workers': 2,
}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

model = ChainableMarkovModel(
    vocab_size=vocab.vocab_size,
    d_latent=config['d_latent'],
    d_hidden=config['d_hidden'],
    dropout=config['dropout'],
)

# Count every parameter tensor's elements and record the total on the W&B run.
num_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {num_params:,}")
wandb.config.update({'num_params': num_params})

## Train

In [None]:
# Optimisation settings are taken directly from the hyperparameter dict.
optim_settings = {
    key: config[key] for key in ('epochs', 'lr', 'weight_decay', 'gradient_clip')
}

# Passing `vocab` and `sample_seed` enables per-epoch sample generation inside
# train(); use_wandb=True turns on W&B logging there.
model, history = train(
    model,
    train_loader,
    val_loader,
    device,
    use_wandb=True,
    vocab=vocab,
    sample_seed="The ",
    **optim_settings,
)

## Save Model Checkpoint

In [None]:
from datetime import datetime

# Timestamp is reused later for the generated-text filename, so outputs pair up.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
save_path = f"{OUTPUT_DIR}/markov_model_{timestamp}.pt"

# Save everything needed to rebuild the model and vocabulary later.
# `vocab_size`, `dropout` and the full `config` are stored alongside the
# original keys so the architecture can be reconstructed exactly; the original
# keys are unchanged, so existing loading code keeps working.
# NOTE: state_dict tensors keep their device; load with map_location when the
# target runtime may lack a GPU.
torch.save({
    'model_state_dict': model.state_dict(),
    'vocab_char_to_idx': vocab.char_to_idx,
    'vocab_idx_to_char': vocab.idx_to_char,
    'vocab_size': vocab.vocab_size,
    'd_latent': config['d_latent'],
    'd_hidden': config['d_hidden'],
    'dropout': config['dropout'],
    'config': dict(config),
    'history': history,
}, save_path)

print(f"Model saved to: {save_path}")

# Upload the checkpoint to W&B as a versioned model artifact.
artifact = wandb.Artifact('markov-model', type='model')
artifact.add_file(save_path)
wandb.log_artifact(artifact)

## Generate Text

In [None]:
import html

seed = "Follow those"
print(f"Seed: '{seed}'")
print("-" * 40)
generated_text = generate(model, vocab, seed, max_length=200, temperature=0.8, device=device)
print(generated_text)

# Escape before embedding in HTML: a raw '<', '>' or '&' in the sample would
# otherwise be parsed as markup and render incorrectly in W&B.
wandb.log({'generated_text': wandb.Html(f'<pre>{html.escape(generated_text)}</pre>')})

# Save the sample next to the checkpoint. This reuses `timestamp` from the
# save cell, so run that cell first. UTF-8 matches how the corpus was read.
generated_path = f"{OUTPUT_DIR}/generated_{timestamp}.txt"
with open(generated_path, 'w', encoding='utf-8') as f:
    f.write(generated_text)
print(f"\nGenerated text saved to: {generated_path}")

## Chainability Test

In [None]:
# Smoke test of the chained forward pass: for n = 2..5, encode a run of n 'a'
# characters and print the shape of the latent that forward_chain returns.
# This checks shapes only, not that the composition is actually consistent.
model.eval()  # the model is left in eval mode after this cell
with torch.no_grad():
    for n in range(2, 6):
        latent = model.forward_chain([vocab.encode('a' * n)], device)
        print(f"{n}-gram: latent shape = {latent.shape}")

## Finish W&B Run

In [None]:
# Close the active W&B run so any pending logged data and artifacts are flushed.
wandb.finish()

## Load Saved Model (note: `/content/out` is wiped when the runtime ends — copy checkpoints elsewhere to reuse them in later sessions)

In [None]:
# Reload a checkpoint written by the save cell. Set CHECKPOINT_PATH to the
# file to load; with the default of None this cell does nothing.
CHECKPOINT_PATH = None  # e.g. f"{OUTPUT_DIR}/markov_model_YYYYMMDD_HHMMSS.pt"

if CHECKPOINT_PATH is not None:
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    # NOTE(review): recent PyTorch defaults torch.load to weights_only=True; if
    # `history` holds non-primitive objects, pass weights_only=False (only for
    # trusted files).
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

    # Rebuild the vocab without re-reading the corpus by bypassing __init__.
    # NOTE(review): assumes CharVocab needs only these three attributes.
    vocab = CharVocab.__new__(CharVocab)
    vocab.char_to_idx = checkpoint['vocab_char_to_idx']
    vocab.idx_to_char = checkpoint['vocab_idx_to_char']
    vocab.vocab_size = len(vocab.char_to_idx)

    model = ChainableMarkovModel(
        vocab_size=vocab.vocab_size,
        d_latent=checkpoint['d_latent'],
        d_hidden=checkpoint['d_hidden'],
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()  # inference mode (dropout, if any, is disabled) for generation
    print("Model loaded!")