## 步骤1: GPU检查和Drive挂载

In [None]:
# 检查GPU
!nvidia-smi

import torch
print(f"CUDA可用: {torch.cuda.is_available()}")
print(f"GPU型号: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'}")

In [None]:
# Mount Google Drive at /content/drive
from google.colab import drive
drive.mount('/content/drive')

## 步骤2: 下载/上传代码

**选项A**: 从GitHub clone(将 `YOUR_USERNAME` 替换为实际的仓库所有者)
```bash
!git clone https://github.com/YOUR_USERNAME/tess-diffusion.git
```

**选项B**: 从Drive加载(假设已上传)

In [None]:
# Option A: clone from GitHub (uncomment the two lines below to use it)
# NOTE(review): the markdown above clones YOUR_USERNAME/tess-diffusion while this
# line points at allenai/tess-diffusion - confirm which repository contains the
# scripts used later in this notebook.
# !git clone https://github.com/allenai/tess-diffusion.git
# %cd tess-diffusion

# Option B: work from a copy already uploaded to Drive
%cd /content/drive/MyDrive/tess-diffusion

# List the .txt and .py files in the working directory
!ls -lh *.txt *.py

## 步骤3: 安装依赖 (~5分钟)

In [None]:
# Install the pinned PyTorch stack (torch 1.12.0 / torchvision 0.13.0 CUDA 11.3
# builds, torchaudio 0.12.0), using the PyTorch cu113 wheel index as an extra index
!pip install -q torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 \
    --extra-index-url https://download.pytorch.org/whl/cu113

# Install the pinned transformers/diffusers/datasets/accelerate versions plus
# the logging and metric packages
!pip install -q transformers==4.25.1 diffusers==0.7.2 datasets==2.14.6 \
    accelerate==0.12.0 tensorboard scipy scikit-learn nltk \
    sacrebleu evaluate bert_score

# NOTE(review): printed unconditionally; it does not check the outcome of the
# pip commands above.
print("✓ 依赖安装完成")

In [None]:
# Report the installed library versions and whether CUDA is usable
import torch
import transformers
import diffusers

for label, module in (
    ("PyTorch", torch),
    ("Transformers", transformers),
    ("Diffusers", diffusers),
):
    print(f"{label}: {module.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

## 步骤4: 扩展Tokenizer词汇表 (~3分钟)

In [None]:
# Preview the first three lines of the training file and count the lines of
# every tess_*.txt file
!head -n 3 tess_train1_oneline.txt
!wc -l tess_*.txt

In [None]:
# Extend the tokenizer vocabulary - key step!
# Runs extend_tokenizer_vocab.py on the training file with roberta-base as the
# base model; output goes to extended_tokenizer/
!python extend_tokenizer_vocab.py \
    --train_file tess_train1_oneline.txt \
    --base_model roberta-base \
    --output_dir extended_tokenizer

# Show the statistics file from the output directory
!cat extended_tokenizer/vocab_extension_stats.json

In [None]:
# Validate tokenization: run validate_config.py with the extended tokenizer,
# the GPU config and the training file, with --check_tokenization enabled and
# --num_sample_entities set to 50
!python validate_config.py \
    --checkpoint extended_tokenizer \
    --config configs/tess_gpu_oneline_sc.json \
    --train_file tess_train1_oneline.txt \
    --check_tokenization \
    --num_sample_entities 50

## 步骤5: 配置训练参数

In [None]:
import json

# Start from the default configuration shipped with the code
with open('configs/tess_gpu_oneline_sc.json', 'r') as src:
    config = json.load(src)

# Colab-specific overrides, applied on top of the defaults
colab_overrides = dict(
    tokenizer_name='extended_tokenizer',
    output_dir='/content/drive/MyDrive/tess_outputs',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,  # set to 1 for a quick sanity run
    save_steps=500,
    eval_steps=500,
    logging_steps=50,
    fp16=True,
    time_save_interval_seconds=1800,
    gdrive_backup_dir='/content/drive/MyDrive/tess_backups',
    backup_keep_last=2,
)
for option, value in colab_overrides.items():
    config[option] = value

# Persist the merged configuration for the training step
with open('configs/tess_colab.json', 'w') as dst:
    json.dump(config, dst, indent=2)

print("✓ 配置已更新")
print(json.dumps(config, indent=2))

## 步骤6: 训练模型

### 选项A: 快速验证 (1 epoch, ~2小时)

In [None]:
# Quick training run: 1 epoch, all options passed on the command line, output
# written to tess_outputs_quick on Drive
# NOTE(review): no --do_train / --do_eval / --evaluation_strategy flags are
# passed here - confirm run_mlm.py trains and evaluates without them.
!python run_mlm.py \
    --model_name_or_path roberta-base \
    --tokenizer_name extended_tokenizer \
    --train_file tess_train1_oneline.txt \
    --validation_file tess_valid1_oneline.txt \
    --output_dir /content/drive/MyDrive/tess_outputs_quick \
    --line_by_line True \
    --max_seq_length 256 \
    --pad_to_max_length True \
    --per_device_train_batch_size 8 \
    --num_train_epochs 1 \
    --save_steps 1000 \
    --eval_steps 1000 \
    --logging_steps 50 \
    --fp16 True \
    --simplex_value 5 \
    --num_diffusion_steps 500 \
    --self_condition logits_addition \
    --self_condition_zeros_after_softmax True \
    --overwrite_output_dir True

### 选项B: 完整训练 (3 epochs, ~6-7小时)

In [None]:
# Full training (3 epochs), driven by the config file written in step 5
!python run_mlm.py configs/tess_colab.json

## 步骤7: 监控训练 (在训练时运行)

In [None]:
# Load the TensorBoard notebook extension and point it at the training output
# directory on Drive
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/tess_outputs

In [None]:
# 查看最新checkpoints
!ls -lht /content/drive/MyDrive/tess_outputs/checkpoint-* | head -5

## 步骤8: 评测模型

### 快速评测 (200 queries, ~5分钟)

In [None]:
# Locate the newest training checkpoint
import os
import re

# Directory the training run wrote its checkpoints to. Point this at
# '/content/drive/MyDrive/tess_outputs_quick' if you ran the quick 1-epoch option.
output_dir = '/content/drive/MyDrive/tess_outputs'


def find_latest_checkpoint(root):
    """Return the path of the highest-step ``checkpoint-<N>`` entry in *root*.

    Only names of the exact form ``checkpoint-<digits>`` are considered, so
    stray entries such as ``checkpoint-tmp`` are skipped instead of raising.

    Args:
        root: Directory to scan.

    Returns:
        ``f"{root}/checkpoint-<N>"`` for the largest ``N``, or ``None`` when
        *root* does not exist or contains no matching entry.
    """
    try:
        names = os.listdir(root)
    except FileNotFoundError:
        return None
    by_step = {}
    for name in names:
        match = re.fullmatch(r'checkpoint-(\d+)', name)
        if match:
            by_step[int(match.group(1))] = name
    if not by_step:
        return None
    return f'{root}/{by_step[max(by_step)]}'


# Always (re)assign checkpoint_path so a re-run never leaves a stale value behind
checkpoint_path = find_latest_checkpoint(output_dir)
if checkpoint_path:
    print(f"使用checkpoint: {checkpoint_path}")
else:
    print("未找到checkpoint")

In [None]:
# Quick evaluation: tail mode with the --quick flag, on the checkpoint stored
# in checkpoint_path
!python run_optimized_eval.py \
    --checkpoint {checkpoint_path} \
    --mode tail \
    --quick

### Grid Search 最优参数 (~20分钟)

In [None]:
# Grid search for the best tess_t_eval value, using 500 queries
!python run_optimized_eval.py \
    --checkpoint {checkpoint_path} \
    --grid_search \
    --num_queries 500

### 完整评测 (2000 queries, ~40分钟)

In [None]:
# Tail prediction (given h and r, predict t): 2000 queries, results written to
# eval_tail_results.json
# NOTE(review): --tess_t_eval is fixed at 60 here; if the grid search reports a
# different best value, presumably it should be used instead - confirm.
!python run_optimized_eval.py \
    --checkpoint {checkpoint_path} \
    --mode tail \
    --num_queries 2000 \
    --tess_t_eval 60 \
    --neg_k 128 \
    --output eval_tail_results.json

In [None]:
# Head prediction (given r and t, predict h): 2000 queries, results written to
# eval_head_results.json
# NOTE(review): --tess_t_eval is fixed at 60 here; if the grid search reports a
# different best value, presumably it should be used instead - confirm.
!python run_optimized_eval.py \
    --checkpoint {checkpoint_path} \
    --mode head \
    --num_queries 2000 \
    --tess_t_eval 60 \
    --neg_k 128 \
    --output eval_head_results.json

## 步骤9: 查看结果

In [None]:
import json

# Load the metrics written by the tail and head evaluation runs
with open('eval_tail_results.json', 'r') as tail_file:
    tail_results = json.load(tail_file)

with open('eval_head_results.json', 'r') as head_file:
    head_results = json.load(head_file)

metric_names = ('MRR', 'Hits@1', 'Hits@3', 'Hits@10')
banner = "=" * 60

# Print one section per prediction direction, same metrics in the same order
print(banner)
print("最终评测结果")
print(banner)
for heading, results in (
    ("\nTail预测 (给定h,r预测t):", tail_results),
    ("\nHead预测 (给定r,t预测h):", head_results),
):
    print(heading)
    for metric in metric_names:
        print(f"  {metric}: {results[metric]:.4f}")

print("\n" + banner)

# Relative improvement of tail MRR over the pre-fix baseline, in percent
baseline_tail_mrr = 0.167
tail_mrr = tail_results['MRR']
improvement = (tail_mrr - baseline_tail_mrr) / baseline_tail_mrr * 100
print(f"\nTail MRR提升: {improvement:.1f}%")
print(f"修复前: {baseline_tail_mrr:.4f}")
print(f"修复后: {tail_mrr:.4f}")

## 步骤10: 保存结果到Drive

In [None]:
# 创建结果目录
!mkdir -p /content/drive/MyDrive/tess_final_results

# 复制评测结果
!cp eval_*.json /content/drive/MyDrive/tess_final_results/

# 复制最佳checkpoint
!cp -r {checkpoint_path} /content/drive/MyDrive/tess_final_results/best_checkpoint

print("✓ 结果已保存到 /content/drive/MyDrive/tess_final_results/")

---

## 常见问题

### 1. 训练中断怎么办?
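从checkpoint恢复。步骤6选项B把JSON配置作为唯一参数传给 `run_mlm.py`;HuggingFace风格的脚本通常只在JSON文件是唯一参数时才读取它,因此不建议把配置文件与 `--resume_from_checkpoint` 混用(请以 `run_mlm.py` 的实际参数解析为准)。更稳妥的做法是把恢复路径写入配置:
```python
config['resume_from_checkpoint'] = '/content/drive/MyDrive/tess_outputs/checkpoint-3000'
with open('configs/tess_colab.json', 'w') as f:
    json.dump(config, f, indent=2)
```
然后在notebook单元中重新运行:
```bash
!python run_mlm.py configs/tess_colab.json
```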

### 2. 内存不足?
减小batch size(修改后需重新写入 `configs/tess_colab.json` 才会生效):
```python
config['per_device_train_batch_size'] = 4
config['gradient_accumulation_steps'] = 2
```

### 3. 评测结果仍然低?
- 运行grid search找最优tess_t_eval
- 检查实体是否被正确tokenize
- 确认训练loss < 2.0

---

## 预期性能

修复后预期:
- **Tail MRR**: 35-45% (原16.7%)
- **Tail Hits@10**: 55-65% (原34.7%)
- **训练时间**: 6-7小时 (3 epochs)