# Chainable Markov Chain Model

Trains a character-level model that learns a chainable composition operation in latent space, used for Markov-chain-style next-character prediction.

## Setup: Clone the repository and import the model code

In [None]:
# Clone the repository (once) and make it the working directory, so that the
# `from model import ...` in the next cell resolves to the repo's module.
# The original `!git clone` + `%cd` pair was not re-runnable: after the first
# run the working directory is already the repo, so a second run cloned a
# nested copy and moved into it. The guard below makes the cell idempotent,
# and `check=True` surfaces a failed clone instead of continuing silently.
import os
import subprocess

REPO_URL = 'https://github.com/sughodke/markov-learned.git'
REPO_DIR = 'markov-learned'

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        subprocess.run(['git', 'clone', REPO_URL, REPO_DIR], check=True)
    os.chdir(REPO_DIR)
print(f'Working directory: {os.getcwd()}')

In [None]:
import torch
from torch.utils.data import DataLoader
from model import (
    CharVocab,
    NgramDataset,
    ChainableMarkovModel,
    collate_ngrams,
    train,
    generate,
)

# Run on the GPU when CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## Load Data

In [None]:
# Read the whole corpus into memory as UTF-8 text.
with open('data/shakespeare.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
print(f"Corpus size: {len(text):,} characters")

# The vocabulary is built from the full corpus, before the train/val split.
vocab = CharVocab(text)
print(f"Vocabulary size: {vocab.vocab_size}")

# Contiguous split by character position: first 90% train, last 10% val.
split_idx = int(len(text) * 0.9)
train_text, val_text = text[:split_idx], text[split_idx:]

train_dataset = NgramDataset(train_text, vocab)
val_dataset = NgramDataset(val_text, vocab)
for split_name, split_dataset in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"{split_name} samples: {len(split_dataset):,}")

## Create Model

In [None]:
# Hyperparameters (the first three are passed to ChainableMarkovModel;
# batch_size feeds the DataLoaders and epochs is used by the training cell).
d_latent, d_hidden = 128, 512
dropout = 0.1
batch_size = 128
epochs = 50

# Train and val loaders share every setting except shuffling:
# only the training data is shuffled.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_ngrams, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

model = ChainableMarkovModel(
    vocab_size=vocab.vocab_size,
    d_latent=d_latent,
    d_hidden=d_hidden,
    dropout=dropout,
)
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")

## Train

In [None]:
# Train for `epochs` epochs using the loaders built above; the returned model
# replaces the one constructed in the previous cell.
# NOTE(review): assumes train() moves the model to `device` and uses
# val_loader for validation — confirm in model.py.
model = train(model, train_loader, val_loader, device, epochs=epochs)

## Generate Text

In [None]:
# Sample a continuation of the seed string from the trained model.
seed = "Follow those"
print(f"Seed: '{seed}'", "-" * 40, sep="\n")
print(
    generate(
        model,
        vocab,
        seed,
        max_length=200,
        temperature=0.8,
        device=device,
    )
)

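## Chainability Test

Encode n-grams of increasing length (n = 2 to 5) with `model.forward_chain` and print the shape of each resulting latent.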

In [None]:
# Encode runs of 2..5 repeated 'a' characters with forward_chain and report
# the shape of each resulting latent (evaluation mode, no gradient tracking).
model.eval()
with torch.no_grad():
    for length in range(2, 6):
        encoded = vocab.encode("a" * length)
        chained_latent = model.forward_chain([encoded], device)
        print(f"{length}-gram: latent shape = {chained_latent.shape}")