Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: copilot-extensions/rag-extension
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: main
Choose a base ref
...
head repository: h2floh/github-copilot-auth-extension
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: main
Choose a head ref
Able to merge. These branches can be automatically merged.
  • 6 commits
  • 7 files changed
  • 1 contributor

Commits on Dec 11, 2024

  1. Refactor authentication configuration to support GitHub and EntraID, …

    …updating environment variables and modifying callback paths
    h2floh committed Dec 11, 2024
    Copy the full SHA
    ab5f0ae View commit details
  2. Copy the full SHA
    7ab15d8 View commit details
  3. Copy the full SHA
    a55c4c1 View commit details
  4. Copy the full SHA
    630e35f View commit details
  5. Implement GitHub handle retrieval using access token and add caching …

    …logic for Entra token
    h2floh committed Dec 11, 2024
    Copy the full SHA
    f0fbfc0 View commit details

Commits on Dec 13, 2024

  1. Merge pull request #1 from h2floh/h2floh/combine_entra

    Include additional IdP authentication and token store for retrieval and usage
    h2floh authored Dec 13, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    238b434 View commit details
Showing with 290 additions and 71 deletions.
  1. +5 −2 README.md
  2. +42 −2 agent/service.go
  3. +46 −16 config/info.go
  4. +10 −7 go.mod
  5. +21 −23 go.sum
  6. +29 −6 main.go
  7. +137 −15 oauth/handler.go
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -13,8 +13,11 @@ This project is a Go application that demonstrates how to use retrieval augmente

```
export PORT=8080
export CLIENT_ID=Iv1.0ae52273ad3193eb // the application id
export CLIENT_SECRET="your_client_secret" // generate a new client secret for your application
export GITHUB_CLIENT_ID=Iv1.0ae52273ad3193eb // the application id from GitHub
export GITHUB_CLIENT_SECRET="your_client_secret" // generate a new client secret for your GitHub application
export ENTRA_CLIENT_ID=// the application id from EntraID
export ENTRA_TENANT_ID= // the tenant id from EntraID
export ENTRA_CLIENT_SECRET="your_client_secret" // generate a new client secret for your EntraId application
export FQDN=https://6de513480979.ngrok.app // use ngrok to expose a url
```

44 changes: 42 additions & 2 deletions agent/service.go
Original file line number Diff line number Diff line change
@@ -20,21 +20,25 @@ import (

"github.com/copilot-extensions/rag-extension/copilot"
"github.com/copilot-extensions/rag-extension/embedding"
"github.com/google/go-github/v50/github"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
)

// Service provides and endpoint for this agent to perform chat completions
type Service struct {
pubKey *ecdsa.PublicKey

cache *cache.Cache
// Singleton
datasets []*embedding.Dataset
datasetsInit *sync.Once
}

func NewService(pubKey *ecdsa.PublicKey) *Service {
func NewService(pubKey *ecdsa.PublicKey, cache *cache.Cache) *Service {
return &Service{
pubKey: pubKey,
datasetsInit: &sync.Once{},
cache: cache,
}
}

@@ -64,6 +68,22 @@ func (s *Service) ChatCompletion(w http.ResponseWriter, r *http.Request) {
apiToken := r.Header.Get("X-GitHub-Token")
integrationID := r.Header.Get("Copilot-Integration-Id")

// retrieve GitHub handle
handle := s.resolveGitHubHandleViaToken(apiToken)
fmt.Printf("GitHub handle: %s\n", handle)
entraToken, found := s.cache.Get(handle)

// retrieve Entra token for GitHub Handle
if found {
fmt.Printf("Cache hit for %s: %s\n", handle, entraToken)
} else {
fmt.Printf("Cache miss for %s\n", handle)
}
//
// Now use entra token to authenticate against external API
// ...
//

var req *copilot.ChatRequest
if err := json.Unmarshal(body, &req); err != nil {
fmt.Printf("failed to unmarshal request: %v\n", err)
@@ -191,6 +211,26 @@ func (s *Service) generateCompletion(ctx context.Context, integrationID, apiToke
return nil
}

// resolveGitHubHandleViaToken retrieves the GitHub handle using the access token.
func (s *Service) resolveGitHubHandleViaToken(accessToken string) string {
client := github.NewClient(oauth2.NewClient(context.Background(), oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: accessToken},
)))

user, _, err := client.Users.Get(context.Background(), "")
if err != nil {
fmt.Printf("error retrieving GitHub user: %v", err)
return ""
}

return user.GetLogin()
}

func (s *Service) HelloWorld(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, World!"))
}

// asn1Signature is a struct for ASN.1 serializing/parsing signatures.
type asn1Signature struct {
R *big.Int
62 changes: 46 additions & 16 deletions config/info.go
Original file line number Diff line number Diff line change
@@ -14,17 +14,29 @@ type Info struct {
FQDN string

// ClientID comes from your configured GitHub app
ClientID string
GitHubClientID string

// ClientSecret comes from your configured GitHub app
ClientSecret string
GitHubClientSecret string

// ClientID comes from your configured EntraId app
EntraIdClientID string

// ClientSecret comes from your configured EntraId app
EntraIdClientSecret string

// TenantId comes from your configured EntraId app
EntraIdTenantId string
}

const (
portEnv = "PORT"
clientIdEnv = "CLIENT_ID"
clientSecretEnv = "CLIENT_SECRET"
fqdnEnv = "FQDN"
portEnv = "PORT"
entraIdClientIdEnv = "ENTRA_CLIENT_ID"
entraIdTenantEnv = "ENTRA_TENANT_ID"
entraIdClientSecretEnv = "ENTRA_CLIENT_SECRET"
fqdnEnv = "FQDN"
gitHubClientIdEnv = "GITHUB_CLIENT_ID"
gitHubClientSecretEnv = "GITHUB_CLIENT_SECRET"
)

func New() (*Info, error) {
@@ -38,20 +50,38 @@ func New() (*Info, error) {
return nil, fmt.Errorf("%s environment variable required", fqdnEnv)
}

clientID := os.Getenv(clientIdEnv)
if clientID == "" {
return nil, fmt.Errorf("%s environment variable required", clientIdEnv)
entraIdClientID := os.Getenv(entraIdClientIdEnv)
if entraIdClientID == "" {
return nil, fmt.Errorf("%s environment variable required", entraIdClientIdEnv)
}

entraIdClientSecret := os.Getenv(entraIdClientSecretEnv)
if entraIdClientSecret == "" {
return nil, fmt.Errorf("%s environment variable required", entraIdClientSecretEnv)
}

entraIdTenantId := os.Getenv(entraIdTenantEnv)
if entraIdTenantId == "" {
return nil, fmt.Errorf("%s environment variable required", entraIdTenantEnv)
}

gitHubClientID := os.Getenv(gitHubClientIdEnv)
if gitHubClientID == "" {
return nil, fmt.Errorf("%s environment variable required", gitHubClientIdEnv)
}

clientSecret := os.Getenv(clientSecretEnv)
if clientSecret == "" {
return nil, fmt.Errorf("%s environment variable required", clientSecretEnv)
gitHubClientSecret := os.Getenv(gitHubClientSecretEnv)
if gitHubClientSecret == "" {
return nil, fmt.Errorf("%s environment variable required", gitHubClientSecretEnv)
}

return &Info{
Port: port,
FQDN: fqdn,
ClientID: clientID,
ClientSecret: clientSecret,
Port: port,
FQDN: fqdn,
GitHubClientID: gitHubClientID,
GitHubClientSecret: gitHubClientSecret,
EntraIdClientID: entraIdClientID,
EntraIdClientSecret: entraIdClientSecret,
EntraIdTenantId: entraIdTenantId,
}, nil
}
17 changes: 10 additions & 7 deletions go.mod
Original file line number Diff line number Diff line change
@@ -3,17 +3,20 @@ module github.com/copilot-extensions/rag-extension
go 1.21.6

require (
github.com/google/go-github/v57 v57.0.0
github.com/google/uuid v1.6.0
github.com/invopop/jsonschema v0.12.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/oauth2 v0.22.0
)

require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/ProtonMail/go-crypto v0.0.0-20230217124315-7d5c6f04bbb8 // indirect
github.com/cloudflare/circl v1.1.0 // indirect
github.com/google/go-github/v50 v50.2.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
golang.org/x/crypto v0.7.0 // indirect
golang.org/x/sys v0.6.0 // indirect
)

require (
github.com/google/go-cmp v0.6.0 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible
)
44 changes: 21 additions & 23 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/ProtonMail/go-crypto v0.0.0-20230217124315-7d5c6f04bbb8 h1:wPbRQzjjwFc0ih8puEVAOFGELsn1zoIIYdxvML7mDxA=
github.com/ProtonMail/go-crypto v0.0.0-20230217124315-7d5c6f04bbb8/go.mod h1:I0gYDMZ6Z5GRU7l58bNFSkPTFN6Yl12dsUlAZ8xy98g=
github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
github.com/cloudflare/circl v1.1.0 h1:bZgT/A+cikZnKIwn7xL2OBj012Bmvho/o6RpRvv3GKY=
github.com/cloudflare/circl v1.1.0/go.mod h1:prBCrKB9DV4poKZY1l9zBXg2QJY7mvgRvtMxxK7fi4I=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-github/v57 v57.0.0 h1:L+Y3UPTY8ALM8x+TV0lg+IEBI+upibemtBD8Q9u7zHs=
github.com/google/go-github/v57 v57.0.0/go.mod h1:s0omdnye0hvK/ecLvpsGfJMiRt85PimQh4oygmLIxHw=
github.com/google/go-github/v50 v50.2.0 h1:j2FyongEHlO9nxXLc+LP3wuBSVU9mVxfpdYUexMpIfk=
github.com/google/go-github/v50 v50.2.0/go.mod h1:VBY8FB6yPIjrtKhozXv4FQupxKLS6H4m6xFZlT43q8Q=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI=
github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A=
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/oauth2 v0.22.0 h1:BzDx2FehcG7jJwgWLELCdmLuxk2i+x9UDpSiss2u0ZA=
golang.org/x/oauth2 v0.22.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
35 changes: 29 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
@@ -10,19 +10,33 @@ import (
"net/url"
"os"
"strings"
"time"

"github.com/copilot-extensions/rag-extension/agent"
"github.com/copilot-extensions/rag-extension/config"
"github.com/copilot-extensions/rag-extension/oauth"

"github.com/patrickmn/go-cache"
)

var memCache *cache.Cache

func main() {

initCache()

if err := run(); err != nil {
fmt.Println(err)
os.Exit(1)
}
}

func initCache() {
// Create a cache with a default expiration time of 5 minutes, and which
// purges expired items every 10 minutes
memCache = cache.New(5*time.Minute, 10*time.Minute)
}

func run() error {
pubKey, err := fetchPublicKey()
if err != nil {
@@ -34,20 +48,29 @@ func run() error {
return fmt.Errorf("error fetching config: %w", err)
}

me, err := url.Parse(config.FQDN)
callbackFromGitHub, err := url.Parse(config.FQDN)
if err != nil {
return fmt.Errorf("unable to parse HOST environment variable: %w", err)
}

callbackFromGitHub.Path = "auth/callback/github"

callbackFromEntra, err := url.Parse(config.FQDN)
if err != nil {
return fmt.Errorf("unable to parse HOST environment variable: %w", err)
}

me.Path = "auth/callback"
callbackFromEntra.Path = "auth/callback/entra"

oauthService := oauth.NewService(config.ClientID, config.ClientSecret, me.String())
http.HandleFunc("/auth/authorization", oauthService.PreAuth)
http.HandleFunc("/auth/callback", oauthService.PostAuth)
oauthService := oauth.NewService(config.GitHubClientID, config.GitHubClientSecret, config.EntraIdClientID, config.EntraIdClientSecret, config.EntraIdTenantId, callbackFromGitHub.String(), callbackFromEntra.String(), memCache)
http.HandleFunc("/auth/authorization", oauthService.PreAuthGitHub)
http.HandleFunc("/auth/callback/github", oauthService.PostAuthGitHub)
http.HandleFunc("/auth/callback/entra", oauthService.PostAuthEntra)

agentService := agent.NewService(pubKey)
agentService := agent.NewService(pubKey, memCache)

http.HandleFunc("/agent", agentService.ChatCompletion)
http.HandleFunc("/", agentService.HelloWorld)

fmt.Println("Listening on port", config.Port)
return http.ListenAndServe(":"+config.Port, nil)
152 changes: 137 additions & 15 deletions oauth/handler.go
Original file line number Diff line number Diff line change
@@ -1,50 +1,84 @@
package oauth

import (
"context"
"fmt"
"net/http"

"github.com/google/go-github/v50/github"
"github.com/google/uuid"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"net/http"
)

// Service provides endpoints to allow this agent to be authorized.
type Service struct {
conf *oauth2.Config
confGitHub *oauth2.Config
confEntra *oauth2.Config
cache *cache.Cache
}

// resolveGitHubHandleViaToken retrieves the GitHub handle using the access token.
func (s *Service) resolveGitHubHandleViaToken(accessToken string) string {
client := github.NewClient(oauth2.NewClient(context.Background(), oauth2.StaticTokenSource(
&oauth2.Token{AccessToken: accessToken},
)))

user, _, err := client.Users.Get(context.Background(), "")
if err != nil {
fmt.Printf("error retrieving GitHub user: %v", err)
return ""
}

return user.GetLogin()
}

func NewService(clientID, clientSecret, callback string) *Service {
func NewService(gitHubClientID, gitHubClientSecret, entraClientId, entraClientSecret, entraTenantId, callbackGitHub, callbackEntra string, cache *cache.Cache) *Service {
return &Service{
conf: &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: callback,
confGitHub: &oauth2.Config{
ClientID: gitHubClientID,
ClientSecret: gitHubClientSecret,
RedirectURL: callbackGitHub,
Endpoint: oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
},
},
confEntra: &oauth2.Config{
ClientID: entraClientId,
ClientSecret: entraClientSecret,
RedirectURL: callbackEntra,
Scopes: []string{"openid", "profile", "email"},
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", entraTenantId),
TokenURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", entraTenantId),
},
},
cache: cache,
}
}

const (
STATE_COOKIE = "oauth_state"
STATE_COOKIE_GITHUB = "oauth_state_github"
STATE_COOKIE_ENTRA = "oauth_state_entra"
GITHUB_ID_COOKIE = "github_handle"
)

// PreAuth is the landing page that the user arrives at when they first attempt
// to use the agent while unauthorized. You can do anything you want here,
// including making sure the user has an account on your side. At some point,
// you'll probably want to make a call to the authorize endpoint to authorize
// the app.
func (s *Service) PreAuth(w http.ResponseWriter, r *http.Request) {
func (s *Service) PreAuthGitHub(w http.ResponseWriter, r *http.Request) {
// In our example, we're not doing anything except going through the
// authorization flow. This is standard Oauth2.

verifier := oauth2.GenerateVerifier()
state := uuid.New()

url := s.conf.AuthCodeURL(state.String(), oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
url := s.confGitHub.AuthCodeURL(state.String(), oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier))
stateCookie := &http.Cookie{
Name: STATE_COOKIE,
Name: STATE_COOKIE_GITHUB,
Value: state.String(),
MaxAge: 10 * 60, // 10 minutes in seconds
Secure: true,
@@ -61,33 +95,121 @@ func (s *Service) PreAuth(w http.ResponseWriter, r *http.Request) {
// above, you can do anything you want here. A common thing you might do is
// get the user information and then perform some sort of account linking in
// your database.
func (s *Service) PostAuth(w http.ResponseWriter, r *http.Request) {
func (s *Service) PostAuthGitHub(w http.ResponseWriter, r *http.Request) {
state := r.URL.Query().Get("state")
code := r.URL.Query().Get("code")

stateCookie, err := r.Cookie(STATE_COOKIE_GITHUB)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("state cookie not found"))
return
}

// Important: Compare the state! This prevents CSRF attacks
if state != stateCookie.Value {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("invalid state"))
return
}

githubToken, err := s.confGitHub.Exchange(r.Context(), code)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf("error exchange code for token: %v", err)))
return
}

// Response contains an GitHub access token
// TODO retrieve user id/information (via go-github)
github_handle := s.resolveGitHubHandleViaToken(githubToken.AccessToken) //"github_handle"

// Now do the same thing for Entra Authentication
s.PreAuthEntra(github_handle, w, r) // <- chain github id/information

// w.WriteHeader(http.StatusOK)
// w.Write([]byte("All done! Please return to the app"))
}

func (s *Service) PreAuthEntra(github_handle string, w http.ResponseWriter, r *http.Request) {
// In our example, we're not doing anything except going through the
// authorization flow. This is standard Oauth2.

// verifier := oauth2.GenerateVerifier()
state := uuid.New()

url := s.confEntra.AuthCodeURL(state.String(), oauth2.AccessTypeOnline)
stateCookie := &http.Cookie{
Name: STATE_COOKIE_ENTRA,
Value: state.String(),
MaxAge: 10 * 60, // 10 minutes in seconds
Secure: true,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}

// to chain github id/information back to the callback function
githubIdCookie := &http.Cookie{
Name: GITHUB_ID_COOKIE,
Value: github_handle,
MaxAge: 10 * 60, // 10 minutes in seconds
Secure: true,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}

http.SetCookie(w, stateCookie)
http.SetCookie(w, githubIdCookie)
w.Header().Set("location", url)
w.WriteHeader(http.StatusFound)
}

func (s *Service) PostAuthEntra(w http.ResponseWriter, r *http.Request) {
state := r.URL.Query().Get("state")
code := r.URL.Query().Get("code")

stateCookie, err := r.Cookie(STATE_COOKIE)
stateCookie, err := r.Cookie(STATE_COOKIE_ENTRA)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("state cookie not found"))
return
}

githubIdCookie, err := r.Cookie(GITHUB_ID_COOKIE)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("github id cookie not found"))
return
}

// Important: Compare the state! This prevents CSRF attacks
if state != stateCookie.Value {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("invalid state"))
return
}

_, err = s.conf.Exchange(r.Context(), code)
entraToken, err := s.confEntra.Exchange(r.Context(), code)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(fmt.Sprintf("error exchange code for token: %v", err)))
return
}
// map and store token in cache
s.cache.Set(githubIdCookie.Value, entraToken.AccessToken, cache.DefaultExpiration)

// Response contains an access token, now the world is your oyster. Get user information and perform account linking, or do whatever you want from here.
s.TestCache(githubIdCookie.Value)

w.WriteHeader(http.StatusOK)
w.Write([]byte("All done! Please return to the app"))
}

func (s *Service) TestCache(item string) {
// Debug out the values of the entraToken
result, found := s.cache.Get(item)
if found {
fmt.Printf("Cache hit for %s: %s\n", item, result)
} else {
fmt.Printf("Cache miss for %s\n", item)
}
}