# sgkit GWAS tutorial in Google Colab (classroom version)
This notebook adapts the official sgkit GWAS tutorial to run smoothly in Colab.

**What students will do**
1) Install sgkit with VCF support
2) Download a small public 1000 Genomes VCF (≈20MB) and sample annotations
3) Convert VCF → Zarr (faster downstream)
4) Run a toy GWAS for the provided `CaffeineConsumption` trait
5) Show why population structure confounds GWAS, then correct it using PCA covariates

Source data and overall flow follow the official sgkit GWAS tutorial.

In [None]:
# Colab setup
# VCF support is an "extra" in sgkit (installs cyvcf2). citeturn10search0
!pip -q install 'sgkit[vcf]'

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

import sgkit as sg
from sgkit.io.vcf import vcf_to_zarr

# Report the installed sgkit version so students can compare environments.
print("sgkit version:", sg.__version__)
# Dataset reprs: keep data variables expanded, collapse attrs.
# The trailing semicolon suppresses this call's cell output.
xr.set_options(display_expand_data_vars=True, display_expand_attrs=False);

## 1) Download public toy data
We use the same small public 1000 Genomes subset that the sgkit tutorial uses.

In [None]:
from pathlib import Path
import requests

VCF_URL = "https://storage.googleapis.com/sgkit-gwas-tutorial/1kg.vcf.bgz"
VCF_PATH = Path("1kg.vcf.bgz")

if not VCF_PATH.exists():
    print("Downloading VCF...")
    # Stream into a temporary ".part" file and rename only after the download
    # has finished, so an interrupted run never leaves a truncated file that
    # the exists() check above would later mistake for a complete VCF.
    partial_path = VCF_PATH.with_name(VCF_PATH.name + ".part")
    # timeout: fail instead of hanging forever on a stalled connection.
    # The `with` block closes the streamed response (and its connection).
    with requests.get(VCF_URL, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                if chunk:
                    f.write(chunk)
    partial_path.replace(VCF_PATH)
    print("Saved:", VCF_PATH, "bytes:", VCF_PATH.stat().st_size)
else:
    print("VCF already exists:", VCF_PATH, "bytes:", VCF_PATH.stat().st_size)

# Tab-separated per-sample annotation table, read straight from the URL and
# indexed by its "Sample" column.
ANNOTATIONS_URL = "https://storage.googleapis.com/sgkit-gwas-tutorial/1kg_annotations.txt"
df_anno = pd.read_csv(
    ANNOTATIONS_URL,
    delimiter="\t",
    index_col="Sample",
)
df_anno.head()

## 2) Convert VCF → Zarr and load
Zarr is a chunked format that makes downstream operations much faster.

In [None]:
ZARR_PATH = Path("1kg.zarr")

if ZARR_PATH.exists():
    # A previous run already produced the store; skip the conversion.
    print("Zarr already exists:", ZARR_PATH)
else:
    print("Converting VCF → Zarr (first run only)...")
    # Per-call FORMAT fields to carry over into the Zarr store.
    format_fields = ["FORMAT/GT", "FORMAT/DP", "FORMAT/GQ", "FORMAT/AD"]
    vcf_to_zarr(
        str(VCF_PATH),
        str(ZARR_PATH),
        fields=format_fields,
        # NOTE(review): presumably overrides the header's Number for AD — confirm against sgkit docs.
        field_defs={"FORMAT/AD": {"Number": "R"}},
        max_alt_alleles=1,
    )

# Open the converted store as a dataset and display it.
ds = sg.load_dataset(str(ZARR_PATH))
ds

## 3) Attach sample annotations
We join the sample annotations (sex, superpopulation, caffeine intake, etc.) onto the genotype dataset.

In [None]:
# Left-join the annotations onto the genotype dataset, keyed on sample ID.
# The table is reordered to the dataset's sample order (samples without an
# annotation row get NaN) and each column is attached along "samples".
# Unlike a set_index / reset_index / reset_coords(drop=True) round trip, this
# keeps the `sample_id` variable, which is displayed at the end of this cell,
# and the cell can be re-run safely.
sample_ids = ds["sample_id"].values
df_aligned = df_anno.reindex(sample_ids)
for column in df_aligned.columns:
    ds[column] = ("samples", df_aligned[column].to_numpy())

# Keep only samples with non-missing trait values (defensive).
# A concrete boolean mask with isel avoids relying on `xr.ufuncs`, which is
# not available in every xarray release.
has_trait = ds["CaffeineConsumption"].notnull().values
ds = ds.isel(samples=has_trait)

# `Dataset.sizes` is the supported name -> length mapping (`Dataset.dims[...]`
# lookups are deprecated in recent xarray).
print("Samples:", ds.sizes["samples"], "Variants:", ds.sizes["variants"])
ds[["sample_id", "SuperPopulation", "isFemale", "CaffeineConsumption"]].isel(samples=slice(5))

## 4) Basic variant QC (MAF + HWE)
This is a minimal QC step for teaching purposes.
We compute allele frequencies and Hardy-Weinberg p-values, then keep variants with alternate-allele frequency > 1% and HWE p-value > 1e-6.

In [None]:
# QC thresholds (same values as the sgkit tutorial).
MIN_ALT_AF = 0.01   # keep variants whose alternate-allele frequency exceeds 1%
MIN_HWE_P = 1e-6    # keep variants whose Hardy-Weinberg p-value exceeds 1e-6

# Variant summary statistics
ds = sg.variant_stats(ds)  # adds variant_allele_frequency, call rate, etc.
ds = sg.hardy_weinberg_test(ds)  # adds variant_hwe_p_value

# Column 1 of the allele-frequency matrix is the (single) alternate allele.
af_alt = ds["variant_allele_frequency"][:, 1]
# Materialise the mask before indexing: the statistics may be lazy (dask-backed),
# and boolean indexing needs concrete values.
keep_variants = ((af_alt > MIN_ALT_AF) & (ds["variant_hwe_p_value"] > MIN_HWE_P)).compute()
ds = ds.sel(variants=keep_variants)

# `Dataset.sizes` is the supported name -> length mapping (`Dataset.dims[...]`
# lookups are deprecated in recent xarray).
print("After QC -> Samples:", ds.sizes["samples"], "Variants:", ds.sizes["variants"])

## 5) GWAS: linear regression per SNP
We regress the phenotype on genotype dosage (0/1/2).

In [None]:
# Per-call dosage: genotype values summed across the ploidy dimension
# (i.e. the number of alternate alleles for biallelic calls).
# NOTE(review): if missing calls are encoded as negative values they would be
# summed too — confirm how missing genotypes are represented in this dataset.
ds["call_dosage"] = ds["call_genotype"].sum(dim="ploidy")

# Baseline association scan with an intercept only: no population-structure
# covariates, so the result is intentionally confounded.
ds_lr0 = sg.gwas_linear_regression(
    ds,
    traits=["CaffeineConsumption"],
    dosage="call_dosage",
    covariates=[],
    add_intercept=True,
)
ds_lr0

In [None]:
# Simple QQ + Manhattan plots (matplotlib only)
import math

def qq_plot(pvals, title="QQ plot"):
    """Draw a QQ plot of observed vs. expected -log10(p) under a uniform null.

    Parameters
    ----------
    pvals : array-like of float
        P-values of any shape; they are flattened before use. Non-finite
        entries (NaN/inf) are dropped so they neither inflate the number of
        tests used for the expected quantiles nor affect the axis limits.
    title : str
        Figure title.

    Raises
    ------
    ValueError
        If no finite p-value is left to plot.
    """
    p = np.asarray(pvals, dtype=float).ravel()
    p = p[np.isfinite(p)]
    if p.size == 0:
        raise ValueError("qq_plot needs at least one finite p-value")

    # Clip so that p == 0 maps to a large finite -log10 value rather than inf.
    p = np.sort(np.clip(p, np.finfo(float).tiny, 1.0))
    n = p.size
    # Expected quantiles of n uniform draws, using midpoint plotting positions.
    exp = -np.log10((np.arange(1, n + 1) - 0.5) / n)
    obs = -np.log10(p)
    # Shared square axis limit, rounded up to the next integer.
    mx = math.ceil(max(exp.max(), obs.max()))

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.scatter(exp, obs, s=6)
    ax.plot([0, mx], [0, mx], linestyle="--")  # y = x reference line
    ax.set_xlim(0, mx)
    ax.set_ylim(0, mx)
    ax.set_xlabel("Expected -log10(p)")
    ax.set_ylabel("Observed -log10(p)")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

def _variant_contig_labels(ds_lr):
    """Return one chromosome label per variant, preferring real contig names."""
    if "variant_contig_name" in ds_lr:
        return ds_lr["variant_contig_name"].values

    codes = np.asarray(ds_lr["variant_contig"].values)
    # NOTE(review): sgkit is expected to keep contig names either in a
    # `contig_id` variable or in `attrs["contigs"]`, with `variant_contig`
    # indexing into that list — confirm for the installed sgkit version.
    if "contig_id" in ds_lr:
        names = np.asarray(ds_lr["contig_id"].values).astype(str)
    else:
        names = np.asarray(ds_lr.attrs.get("contigs", []), dtype=str)

    codes_index_names = (
        names.size > 0
        and codes.size > 0
        and codes.dtype.kind in "iu"
        and codes.min() >= 0
        and codes.max() < names.size
    )
    if codes_index_names:
        return names[codes]
    # Fall back to the raw contig codes as strings.
    return codes.astype(str)


def manhattan_plot(ds_lr, title="Manhattan plot", genomewide_line=5e-8):
    """Draw a Manhattan plot of -log10(p) along concatenated chromosomes.

    Parameters
    ----------
    ds_lr : dataset
        Must provide `variant_position`, `variant_linreg_p_value` and either
        `variant_contig_name` or `variant_contig` (optionally with contig
        names in `contig_id` / `attrs["contigs"]`).
    title : str
        Figure title.
    genomewide_line : float
        P-value at which a dashed horizontal significance line is drawn.
    """
    chr_name = _variant_contig_labels(ds_lr)

    # Work in int64: positions are typically stored as int32 (TODO confirm),
    # and under NumPy 2 promotion rules int32 scalars stay int32 when combined
    # with Python ints, so cumulative genome-wide offsets (> 2**31) would wrap.
    pos = np.asarray(ds_lr["variant_position"].values).astype(np.int64)
    p = np.atleast_1d(np.squeeze(np.asarray(ds_lr["variant_linreg_p_value"].values, dtype=float)))
    # Clip so that p == 0 maps to a large finite value instead of inf.
    p = np.clip(p, np.finfo(float).tiny, 1.0)
    mlogp = -np.log10(p)

    df = pd.DataFrame({"CHR": chr_name, "BP": pos, "MLOGP": mlogp})

    # Sort chromosomes numerically where possible; non-numeric names go last.
    def chr_key(x):
        try:
            return (0, int(x))
        except (TypeError, ValueError):
            return (1, str(x))

    chr_order = sorted(df["CHR"].unique(), key=chr_key)
    chr_max = df.groupby("CHR", sort=False)["BP"].max()

    offsets = {}
    cur = 0  # running x offset, kept as a Python int (no overflow)
    ticks = []
    ticklabels = []
    for c in chr_order:
        m = int(chr_max[c])
        offsets[c] = cur
        ticks.append(cur + m / 2)
        ticklabels.append(str(c))
        cur += m + 1_000_000  # gap between chromosomes

    # Vectorised genome-wide x coordinate: position plus chromosome offset.
    df["X"] = df["BP"] + df["CHR"].map(offsets)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.scatter(df["X"], df["MLOGP"], s=4)
    ax.axhline(-np.log10(genomewide_line), linestyle="--")
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticklabels, rotation=0)
    ax.set_xlabel("Chromosome")
    ax.set_ylabel("-log10(p)")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# Diagnostics for the unadjusted scan: drop the length-1 trait axis from the
# p-value array, then draw the QQ and Manhattan plots.
p0 = ds_lr0.variant_linreg_p_value.squeeze().values
qq_plot(p0, title="QQ (no covariates; typically inflated)")
manhattan_plot(ds_lr0, title="Manhattan (no covariates; confounded)")

## 6) Correct for ancestry using PCA covariates
We compute PCs from genotype data and include the first few PCs + sex as covariates.
This follows the approach in the sgkit GWAS tutorial.

In [None]:
# PCA input: per-call alternate-allele counts
ds_pca = sg.stats.pca.count_call_alternate_alleles(ds)

# Variants to exclude from PCA (as in the tutorial): any sample with a missing
# (negative) count, or no variation across samples.
alt_counts = ds_pca.call_alternate_allele_count
has_missing = (alt_counts < 0).any(dim="samples")
no_variation = alt_counts.std(dim="samples") <= 0.0
variant_mask = has_missing | no_variation
ds_pca = ds_pca.sel(variants=~variant_mask)

# Run PCA and show the explained-variance ratio of the leading components.
ds_pca = sg.pca(ds_pca)
print(ds_pca.sample_pca_explained_variance_ratio.values[:5])

In [None]:
# Scatter of the first two principal components, one colour per superpopulation
pc = ds_pca.sample_pca_projection.values
pop = ds["SuperPopulation"].values  # same sample order as the PCA projection
pops = pd.Categorical(pop)

fig, ax = plt.subplots(figsize=(6, 5))
for lvl in pops.categories:
    m = (pops == lvl)  # boolean selector for this group's samples
    ax.scatter(pc[m, 0], pc[m, 1], s=12, label=str(lvl), alpha=0.8)
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.set_title("PCA of genotypes (colored by superpopulation)")
# Anchor the legend outside the axes so it does not cover any points.
ax.legend(title="SuperPopulation", bbox_to_anchor=(1.02, 1), loc="upper left")
fig.tight_layout()
plt.show()

In [None]:
# Attach the first 3 PCs to the full dataset as per-sample variables, then
# repeat the association scan with sex and those PCs as covariates.
pc_names = ("sample_pca_0", "sample_pca_1", "sample_pca_2")
for k, pc_name in enumerate(pc_names):
    ds[pc_name] = (("samples",), ds_pca.sample_pca_projection[:, k].values)

ds_lr = sg.gwas_linear_regression(
    ds,
    traits=["CaffeineConsumption"],
    dosage="call_dosage",
    covariates=["isFemale", *pc_names],
    add_intercept=True,
)

# Same diagnostics as for the unadjusted scan, now on the adjusted p-values.
p = ds_lr.variant_linreg_p_value.squeeze().values
qq_plot(p, title="QQ (sex + 3 PCs; typically better calibrated)")
manhattan_plot(ds_lr, title="Manhattan (sex + 3 PCs)")

## 7) Export top hits (optional)
This makes it easy for students to inspect the top associations and download them.

In [None]:
# Make a results table and show top hits

# Chromosome label per variant: prefer explicit names; otherwise translate the
# integer contig codes into contig names when a name list is available.
if "variant_contig_name" in ds_lr:
    chrom = ds_lr["variant_contig_name"].values
else:
    chrom = np.asarray(ds_lr["variant_contig"].values)
    # NOTE(review): sgkit is expected to keep contig names in `contig_id` or in
    # `attrs["contigs"]`, indexed by `variant_contig` — confirm for this version.
    if "contig_id" in ds_lr:
        contig_names = np.asarray(ds_lr["contig_id"].values).astype(str)
    else:
        contig_names = np.asarray(ds_lr.attrs.get("contigs", []), dtype=str)
    if (
        contig_names.size > 0
        and chrom.size > 0
        and chrom.dtype.kind in "iu"
        and chrom.min() >= 0
        and chrom.max() < contig_names.size
    ):
        chrom = contig_names[chrom]

# `Dataset.sizes` is the supported name -> length mapping (`Dataset.dims[...]`
# lookups are deprecated in recent xarray).
n_variants = ds_lr.sizes["variants"]

res = pd.DataFrame({
    "CHR": chrom,
    "BP": ds_lr["variant_position"].values,
    # Fall back to a running index when the dataset carries no variant IDs.
    "ID": ds_lr["variant_id"].values if "variant_id" in ds_lr else np.arange(n_variants),
    "BETA": ds_lr["variant_linreg_beta"].squeeze().values,
    "T": ds_lr["variant_linreg_t_value"].squeeze().values,
    "P": ds_lr["variant_linreg_p_value"].squeeze().values,
})
# Smallest p-values first.
res = res.sort_values("P")
res.head(10)

In [None]:
# Write the sorted results as a tab-separated file on the Colab VM
# (it can be downloaded from the file browser in the left panel).
out = "sgkit_caffeine_gwas_results.tsv"
res.to_csv(out, index=False, sep="\t")
print("Wrote:", out, "rows:", res.shape[0])