# In [None]:
import itertools
import pickle
import zlib

import torch

# 1. Load the model checkpoint, mapping every tensor onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Recent PyTorch releases default to weights_only=True, which
# rejects arbitrary pickled objects — confirm this checkpoint is compatible.
checkpoint = torch.load('medical_model.pth', map_location='cpu')
# Assumes the checkpoint is a dict-like object (it must expose .keys()).
print(f"Modèle chargé. Clés: {list(checkpoint.keys())}")

# 2. Load the patient code sequences.
# NOTE(review): pickle.load can execute arbitrary code — never point this at
# an untrusted file.
with open('medical_sequences_pure.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)
    sequences = data['sequences']

# 3. Build a minimal embedding model.
class QuickModel(torch.nn.Module):
    """Bag-of-codes encoder: embed each token id, then average over positions.

    Args:
        vocab_size: number of rows in the embedding table; valid ids are
            0 .. vocab_size - 1.
        embed_dim: length of each embedding vector.
    """

    def __init__(self, vocab_size=1000, embed_dim=128):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
        # Pools the last axis down to length 1, i.e. a mean over positions.
        self.pool = torch.nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Return one vector per row of ``x``.

        ``x`` holds token ids with shape [batch, seq_len]. The result has shape
        [batch, embed_dim]. Every position, including any padding ids the
        caller inserted, contributes equally to the mean.
        """
        per_token = self.embed(x)                         # [batch, seq_len, dim]
        channels_first = torch.transpose(per_token, 1, 2)  # [batch, dim, seq_len]
        pooled = self.pool(channels_first)                 # [batch, dim, 1]
        # The pooled axis has length exactly 1, so indexing it drops it.
        return pooled[..., 0]

# 4. Instantiate the model and load whatever weights the checkpoint provides.
model = QuickModel()
if 'model_state_dict' in checkpoint:
    # strict=False tolerates key mismatches, but on its own it hides them:
    # a partially loaded model would silently keep random weights. Report
    # mismatches explicitly. (Tensor shape mismatches on shared keys still
    # raise RuntimeError even with strict=False.)
    load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    if load_result.missing_keys:
        print(f"Attention: poids manquants (initialisation aléatoire): {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Attention: poids ignorés (absents du modèle): {load_result.unexpected_keys}")
else:
    # Without this key nothing is loaded; make that visible instead of
    # silently continuing with an untrained model.
    print("Attention: 'model_state_dict' absent du checkpoint; le modèle garde des poids aléatoires.")

# Inference mode (affects dropout/batch-norm style layers, if any are added).
model.eval()

# 5. Example on the first few patients.
def encode_codes(codes, vocab_size, max_len=20):
    """Map the first ``max_len`` codes of a sequence to embedding indices.

    Placeholder encoding (to be adapted). Each code is bucketed with CRC32.
    The built-in ``hash`` on ``str`` is salted per interpreter process unless
    PYTHONHASHSEED is set; CRC32 is not, so a given code always maps to the
    same index across runs.

    Index 0 is reserved for padding / "no code", so real codes land in
    ``[1, vocab_size)``. No padding is appended: a single-sample batch does
    not need it, and QuickModel averages over every position, so padded
    positions would be mixed into the embedding. An empty sequence encodes
    to ``[0]`` so the model still receives one position.

    Args:
        codes: sliceable sequence of codes; each is hashed via ``str(code)``.
        vocab_size: size of the embedding table the indices must fit into.
        max_len: maximum number of leading codes to keep.

    Returns:
        A non-empty list of ints.

    Raises:
        ValueError: if ``vocab_size < 2`` (no room besides the padding index).
    """
    if vocab_size < 2:
        raise ValueError(f"vocab_size must be >= 2, got {vocab_size}")
    n_buckets = vocab_size - 1
    encoded = [
        1 + zlib.crc32(str(code).encode("utf-8")) % n_buckets
        for code in codes[:max_len]
    ]
    return encoded or [0]


# Assumes `sequences` maps patient id -> sliceable sequence of codes.
for pid, seq in itertools.islice(sequences.items(), 3):
    print(f"\nPatient {pid}:")
    print(f"  Séquence: {seq[:5]}...")

    # Use the model's full vocabulary rather than a hard-coded 100 buckets.
    encoded = encode_codes(seq, model.embed.num_embeddings, max_len=20)

    # Embedding (no autograd bookkeeping needed for inference).
    with torch.no_grad():
        x = torch.tensor([encoded], dtype=torch.long)  # batch of one: [1, len]
        emb = model(x)[0].numpy()

    print(f"  Embedding shape: {emb.shape}")