# HMCAN Training on Google Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sucpark/hmcan/blob/main/notebooks/train_hmcan_colab.ipynb)

Hierarchical Multichannel CNN-based Attention Network for Document Classification

## Phase 1: Foundation Models (HAN, HCAN, HMCAN)

## 1. Environment Setup

In [None]:
# Check GPU availability
!nvidia-smi
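
`nvidia-smi` prints a full status table. For a one-line answer that later cells can branch on, here is a minimal sketch using only the standard library. `gpu_summary` is a hypothetical helper name and is not part of the hmcan package.

```python
import shutil
import subprocess


def gpu_summary() -> str:
    """Return GPU name and memory as reported by nvidia-smi, or a fallback message."""
    # nvidia-smi is only on PATH when a GPU runtime is attached
    if shutil.which("nvidia-smi") is None:
        return "No GPU detected (nvidia-smi not found)"
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return "No GPU detected (nvidia-smi failed)"
    return result.stdout.strip()


print(gpu_summary())
```

If this reports no GPU, switch the Colab runtime type to a GPU before training.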

In [None]:
# Clone repository
!git clone https://github.com/sucpark/hmcan.git
%cd hmcan

In [None]:
# Install dependencies
!pip install -e . -q
!pip install wandb -q

In [None]:
# Download NLTK data
import nltk
nltk.download('punkt')
nltk.download('punkt_tab')

## 2. Weights & Biases Setup (Optional)

In [None]:
# Log in to wandb (optional but recommended)
import wandb
wandb.login()
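
If you prefer not to log in, one option is wandb's offline mode, selected through the `WANDB_MODE` environment variable. The sketch below assumes the training code initializes wandb in the usual way, so that the variable is picked up; that has not been checked against the hmcan source.

```python
import os

# "offline" stores runs locally without contacting the wandb servers;
# "disabled" turns wandb calls into no-ops.
os.environ["WANDB_MODE"] = "offline"

print("WANDB_MODE =", os.environ["WANDB_MODE"])
```

Commands started with `!` in later cells inherit this environment variable.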

## 3. Download Data

In [None]:
# Download Yelp dataset and GloVe embeddings
# --max-samples: Number of samples to use (reduce for faster experiments)
!python scripts/download_data.py --max-samples 10000

## 4. Configuration

In [None]:
# View default HMCAN config
!cat configs/hmcan.yaml

In [None]:
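# Modify config if needed (enable experiment logging)
import yaml

with open('configs/hmcan.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Enable wandb and TensorBoard logging
# (consider setting use_wandb to False if you skipped the login in section 2)
config['use_wandb'] = True
config['use_tensorboard'] = True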

# Save modified config
with open('configs/hmcan_colab.yaml', 'w') as f:
    yaml.dump(config, f, default_flow_style=False)

print("Config saved to configs/hmcan_colab.yaml")
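
The same load, edit, and save pattern is repeated for HAN and HCAN in section 5. The edit step could be factored into a small function, sketched here on a plain dict. `with_logging` is a hypothetical helper, and the `model` key in the stand-in config is made up for illustration; only `use_wandb` and `use_tensorboard` come from this notebook.

```python
import copy


def with_logging(config: dict, use_wandb: bool = True, use_tensorboard: bool = True) -> dict:
    """Return a copy of config with the two logging flags set."""
    updated = copy.deepcopy(config)  # leave the caller's dict untouched
    updated["use_wandb"] = use_wandb
    updated["use_tensorboard"] = use_tensorboard
    return updated


# Stand-in for the dict returned by yaml.safe_load(...)
base = {"model": "hmcan", "use_wandb": False}
colab = with_logging(base)
print(colab)
```

Returning a copy keeps the original config available, which helps when the same base dict is reused with different flags.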

## 5. Train Models

### 5.1 Train HMCAN (Main Model)

In [None]:
!python -m hmcan train --config configs/hmcan_colab.yaml

### 5.2 Train HAN (Baseline)

In [None]:
# Enable wandb and TensorBoard logging for HAN
import yaml  # repeated so this cell also runs on its own

with open('configs/han.yaml', 'r') as f:
    han_config = yaml.safe_load(f)

han_config['use_wandb'] = True
han_config['use_tensorboard'] = True

with open('configs/han_colab.yaml', 'w') as f:
    yaml.dump(han_config, f, default_flow_style=False)

!python -m hmcan train --config configs/han_colab.yaml

### 5.3 Train HCAN

In [None]:
# Enable wandb and TensorBoard logging for HCAN
import yaml  # repeated so this cell also runs on its own

with open('configs/hcan.yaml', 'r') as f:
    hcan_config = yaml.safe_load(f)

hcan_config['use_wandb'] = True
hcan_config['use_tensorboard'] = True

with open('configs/hcan_colab.yaml', 'w') as f:
    yaml.dump(hcan_config, f, default_flow_style=False)

!python -m hmcan train --config configs/hcan_colab.yaml

## 6. Evaluate Models

In [None]:
# Evaluate HMCAN
!python -m hmcan evaluate --checkpoint outputs/hmcan_yelp/checkpoints/best_model.pt

In [None]:
# Evaluate HAN
!python -m hmcan evaluate --checkpoint outputs/han_yelp/checkpoints/best_model.pt

In [None]:
# Evaluate HCAN
!python -m hmcan evaluate --checkpoint outputs/hcan_yelp/checkpoints/best_model.pt
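
The three evaluation cells fail if a model was never trained. A small sketch that builds the evaluate command only for checkpoints that exist is shown below. It reuses the `outputs/<model>_yelp/checkpoints/best_model.pt` layout and the `evaluate --checkpoint` command from the cells above; `eval_commands` itself is a hypothetical helper.

```python
from pathlib import Path


def eval_commands(models, root="outputs"):
    """Return one evaluate command per model whose best checkpoint exists."""
    commands = []
    for name in models:
        ckpt = Path(root) / f"{name}_yelp" / "checkpoints" / "best_model.pt"
        if ckpt.exists():
            commands.append(f"python -m hmcan evaluate --checkpoint {ckpt.as_posix()}")
    return commands


for cmd in eval_commands(["han", "hcan", "hmcan"]):
    print(cmd)
```

Each printed line can be run in a cell by prefixing it with `!`.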

## 7. Save Results to Google Drive

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
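# Copy the contents of outputs/ to Drive.
# Using "outputs/." avoids a nested hmcan_outputs/outputs folder when the destination already exists.
!mkdir -p /content/drive/MyDrive/hmcan_outputs
!cp -r outputs/. /content/drive/MyDrive/hmcan_outputs/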
print("Outputs saved to Google Drive!")
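
Re-running the copy step writes into the same Drive folder each time. To keep runs separate, one option is a timestamped destination, sketched here with `shutil.copytree`. `backup_outputs` and the `hmcan_outputs_<timestamp>` naming are assumptions for illustration; the demo uses temporary folders so it runs anywhere.

```python
import shutil
import tempfile
from datetime import datetime
from pathlib import Path


def backup_outputs(src: str, dest_root: str) -> Path:
    """Copy src into a new timestamped folder under dest_root and return its path."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target = Path(dest_root) / f"hmcan_outputs_{stamp}"
    shutil.copytree(src, target)  # raises FileExistsError if target already exists
    return target


# Demo on temporary folders. In Colab, src would be 'outputs' and
# dest_root '/content/drive/MyDrive' once the Drive is mounted.
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "outputs"
    (src / "hmcan_yelp").mkdir(parents=True)
    (src / "hmcan_yelp" / "log.txt").write_text("demo")
    saved = backup_outputs(str(src), str(Path(tmp) / "drive"))
    print(sorted(p.name for p in saved.rglob("*")))
```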

## 8. Results Summary

In [None]:
import os
import torch

models = ['han', 'hcan', 'hmcan']
results = {}

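# Collect the metrics stored in each best checkpoint; models without a checkpoint are skipped
for model in models:
    ckpt_path = f'outputs/{model}_yelp/checkpoints/best_model.pt'
    if os.path.exists(ckpt_path):
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Fall back to an empty dict if the checkpoint carries no 'metrics' entry
        results[model] = ckpt.get('metrics', {})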

print("=" * 50)
print("Results Summary")
print("=" * 50)
for model, metrics in results.items():
    acc = metrics.get('accuracy', 'N/A')
    if isinstance(acc, float):
        acc = f"{acc*100:.2f}%"
    print(f"{model.upper():8s}: {acc}")
print("=" * 50)
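
The summary above formats the accuracy only when it is a Python `float`. If a checkpoint stores it as an `int` or another numeric type, a slightly more tolerant formatter can help. This is a sketch: `format_accuracy` is a hypothetical helper, and the demo values are placeholders chosen to exercise each branch, not real results.

```python
from numbers import Real


def format_accuracy(value) -> str:
    """Render an accuracy in [0, 1] as a percentage, or 'N/A' if it is not numeric."""
    # bool is a subclass of int, so exclude it explicitly
    if isinstance(value, bool) or not isinstance(value, Real):
        return "N/A"
    return f"{float(value) * 100:.2f}%"


# Placeholder values only, to show the formatting of each input type
demo = {"han": 0.5, "hcan": 1, "hmcan": None}
for name, acc in demo.items():
    print(f"{name.upper():8s}: {format_accuracy(acc)}")
```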

## 9. Attention Visualization

In [None]:
import torch
import matplotlib.pyplot as plt
from hmcan.models import HMCAN
from hmcan.data import YelpDataModule

# Load data
data_module = YelpDataModule(data_dir='data')
data_module.setup()

# Load model
model = HMCAN(
    vocab_size=len(data_module.vocabulary),
    pretrained_embeddings=data_module.pretrained_embeddings,
)
ckpt = torch.load('outputs/hmcan_yelp/checkpoints/best_model.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

print("Model loaded successfully!")

In [None]:
# Get a sample and visualize attention
test_loader = data_module.test_dataloader()
batch = next(iter(test_loader))

with torch.no_grad():
    outputs = model(batch['document'], batch['sentence_lengths'])

# Sentence attention visualization
# (assumes the batch holds a single document, so squeeze() yields a 1-D vector of weights)
sent_attn = outputs['sentence_attention'].squeeze().numpy()

plt.figure(figsize=(10, 4))
plt.bar(range(len(sent_attn)), sent_attn)
plt.xlabel('Sentence Index')
plt.ylabel('Attention Weight')
plt.title('Sentence-level Attention Weights')
plt.tight_layout()
plt.savefig('sentence_attention.png', dpi=150)
plt.show()
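
Alongside the bar chart, it can be useful to list which sentences received the most attention. A minimal sketch on a plain list of weights follows. `top_sentences` is a hypothetical helper and the example weights are illustrative; in this notebook the input would be something like `sent_attn.tolist()`.

```python
def top_sentences(weights, k=3):
    """Return (index, weight) pairs for the k largest attention weights."""
    ranked = sorted(enumerate(weights), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Illustrative weights only, not model output
example = [0.1, 0.4, 0.2, 0.3]
for idx, weight in top_sentences(example, k=2):
    print(f"sentence {idx}: {weight:.2f}")
```

The returned indices match the x-axis of the bar chart, so they can be used to look up the corresponding sentences in the document.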